from typing import Tuple, Union

from flask import jsonify
from vllm import SamplingParams

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json

# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69


class VLLMBackend(LLMBackend):
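    """vLLM implementation of LLMBackend.

    As implemented below:

    * `handle_response()` turns a raw reply from the vLLM backend into the JSON
      payload returned to the client, logging the exchange as it goes.
    * `get_parameters()` maps client-supplied generation settings onto vLLM's
      `SamplingParams`.
    * `validate_request()` enforces the server-wide `max_new_tokens` limit.

    A minimal sketch of how a route handler might drive it (hypothetical names;
    assumes the backend can be constructed with no arguments):

        backend = VLLMBackend()
        ok, err = backend.validate_request(request_json)  # request_json: dict of client params
        if ok:
            sampling_params, err = backend.get_parameters(request_json)
    """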
    # vLLM's defaults, captured as a plain dict via vars() so they can be used as
    # fallback values for any sampling parameter the client does not supply.
    default_params = vars(SamplingParams())

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers):
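        """Convert a raw vLLM backend response into the client-facing JSON payload.

        Extracts the generated text, strips the echoed prompt, logs the exchange via
        `log_prompt()`, and falls back to a SillyTavern-formatted error message when
        the backend did not return valid JSON.
        """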
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except Exception:
            # The response object may be missing (e.g. the request itself failed),
            # so fall back to a sentinel status code.
            response_status_code = 0

        if response_valid_json:
            if len(response_json_body.get('text', [])):
                # Does vLLM return the prompt and the response together?
                # Split with maxsplit=1 and take the last element so a response that
                # does not echo the prompt back does not raise an IndexError.
                backend_response = response_json_body['text'][0].split(prompt, 1)[-1].strip(' ').strip('\n')
            else:
                # Failsafe
                backend_response = ''

            # TODO: how to detect an error?
            # if backend_response == '':
            #     backend_err = True
            #     backend_response = format_sillytavern_err(
            #         f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
            #         f'HTTP CODE {response_status_code}'
            #     )

            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time if not backend_err else None, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({'results': [{'text': backend_response}]}), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    # def validate_params(self, params_dict: dict):
    #     self.default_params = SamplingParams()
    #     try:
    #         sampling_params = SamplingParams(
    #             temperature=params_dict.get('temperature', self.default_params.temperature),
    #             top_p=params_dict.get('top_p', self.default_params.top_p),
    #             top_k=params_dict.get('top_k', self.default_params.top_k),
    #             use_beam_search=True if params_dict['num_beams'] > 1 else False,
    #             length_penalty=params_dict.get('length_penalty', self.default_params.length_penalty),
    #             early_stopping=params_dict.get('early_stopping', self.default_params.early_stopping),
    #             stop=params_dict.get('stopping_strings', self.default_params.stop),
    #             ignore_eos=params_dict.get('ban_eos_token', False),
    #             max_tokens=params_dict.get('max_new_tokens', self.default_params.max_tokens)
    #         )
    #     except ValueError as e:
    #         print(e)
    #         return False, e
    #     return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         model_path = Path(r_json['data'][0]['root']).name
    #         r_json['data'][0]['root'] = model_path
    #         return r_json, None
    #     except Exception as e:
    #         return False, e

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
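        """Build vLLM `SamplingParams` from the client's parameters, using vLLM's own
        defaults for anything the client did not supply.

        Returns `(params_dict, None)` on success, or `(None, error_message)` when
        `SamplingParams` rejects a value.
        """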
        try:
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self.default_params['temperature']),
                top_p=parameters.get('top_p', self.default_params['top_p']),
                top_k=parameters.get('top_k', self.default_params['top_k']),
                use_beam_search=parameters.get('num_beams', 0) > 1,
                stop=parameters.get('stopping_strings', self.default_params['stop']),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None

    def validate_request(self, parameters) -> Tuple[bool, Union[str, None]]:
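        """Reject requests whose `max_new_tokens` exceeds the server-wide limit."""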
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None