local-llm-server/llm_server/llm/vllm/vllm_backend.py

from typing import Tuple, Union

from flask import jsonify
from vllm import SamplingParams

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json

# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
class VLLMBackend(LLMBackend):
    default_params = vars(SamplingParams())

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except AttributeError:
            # The response object may be missing entirely (e.g. the request to the backend failed outright).
            response_status_code = 0

        if response_valid_json:
            if len(response_json_body.get('text', [])):
                # Does vllm return the prompt and the response together???
                # The completion is recovered by splitting the returned text on the original prompt.
                backend_response = response_json_body['text'][0].split(prompt)[1].strip(' ').strip('\n')
            else:
                # Failsafe
                backend_response = ''

            # TODO: how to detect an error?
            # if backend_response == '':
            #     backend_err = True
            #     backend_response = format_sillytavern_err(
            #         f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
            #         f'HTTP CODE {response_status_code}'
            #     )

            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time if not backend_err else None, parameters=parameters, headers=headers,
                       backend_response_code=response_status_code, request_url=request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({'results': [{'text': backend_response}]}), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200
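
    # For reference, a hedged sketch of the response shape handle_response() expects from the
    # vLLM API server (inferred from the parsing above, not pinned to a specific vLLM version):
    #
    #   {
    #       "text": ["<prompt followed by the generated completion>"],
    #       "details": {"generated_tokens": 123}
    #   }
    #
    # Splitting text[0] on the original prompt is what isolates the completion, which is why the
    # prompt/response question in the TODO above matters.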
    # def validate_params(self, params_dict: dict):
    #     self.default_params = SamplingParams()
    #     try:
    #         sampling_params = SamplingParams(
    #             temperature=params_dict.get('temperature', self.default_params.temperature),
    #             top_p=params_dict.get('top_p', self.default_params.top_p),
    #             top_k=params_dict.get('top_k', self.default_params.top_k),
    #             use_beam_search=True if params_dict['num_beams'] > 1 else False,
    #             length_penalty=params_dict.get('length_penalty', self.default_params.length_penalty),
    #             early_stopping=params_dict.get('early_stopping', self.default_params.early_stopping),
    #             stop=params_dict.get('stopping_strings', self.default_params.stop),
    #             ignore_eos=params_dict.get('ban_eos_token', False),
    #             max_tokens=params_dict.get('max_new_tokens', self.default_params.max_tokens)
    #         )
    #     except ValueError as e:
    #         print(e)
    #         return False, e
    #     return True, None
    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         model_path = Path(r_json['data'][0]['root']).name
    #         r_json['data'][0]['root'] = model_path
    #         return r_json, None
    #     except Exception as e:
    #         return False, e
    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        # Map the client-supplied generation parameters onto vLLM SamplingParams,
        # falling back to vLLM's own defaults for anything not provided.
        try:
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self.default_params['temperature']),
                top_p=parameters.get('top_p', self.default_params['top_p']),
                top_k=parameters.get('top_k', self.default_params['top_k']),
                use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
                stop=parameters.get('stopping_strings', self.default_params['stop']),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', self.default_params['max_tokens'])
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None
    def validate_request(self, parameters) -> Tuple[bool, Union[str, None]]:
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None
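

# A minimal usage sketch (hypothetical caller, not part of this module): a route handler
# would be expected to validate the request first, then convert the client parameters into
# a plain dict of SamplingParams values (via vars()) before forwarding the prompt to the
# vLLM backend and handing the reply to handle_response(). Names below are illustrative only.
#
#   backend = VLLMBackend()
#   ok, err = backend.validate_request(client_params)
#   if not ok:
#       return jsonify({'code': 400, 'msg': err}), 400
#   sampling_params, err = backend.get_parameters(client_params)
#   if err:
#       return jsonify({'code': 400, 'msg': err}), 400
#   # ... send prompt + sampling_params to the backend, then:
#   # return backend.handle_response(...)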