local-llm-server/llm_server/llm/vllm/vllm_backend.py
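
"""vLLM backend for the local LLM proxy server.

Maps client-supplied generation parameters onto vllm.SamplingParams and
normalizes the vLLM HTTP response into the JSON payload (SillyTavern-style
'results' list) returned by the proxy routes.
"""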

from typing import Tuple

from flask import jsonify
from vllm import SamplingParams

from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json


# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
class VLLMBackend(LLMBackend):
    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except AttributeError:
            # response may be None (e.g. the request to the backend failed), so there is no status code
            response_status_code = 0
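
        # Expected (assumed) shape of a successful vLLM generate response:
        #   {'text': ['<prompt><generated text>'], 'details': {'generated_tokens': 42}}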
        if response_valid_json:
            if len(response_json_body.get('text', [])):
                # vLLM appears to return the prompt and the generated text together, so strip the prompt back off.
                backend_response = response_json_body['text'][0].split(prompt, 1)[-1].strip()
            else:
                # Failsafe
                backend_response = ''

            # TODO: how to detect an error?
            # if backend_response == '':
            #     backend_err = True
            #     backend_response = format_sillytavern_err(
            #         f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
            #         f'HTTP CODE {response_status_code}'
            #     )

            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({'results': [{'text': backend_response}]}), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    # def validate_params(self, params_dict: dict):
    #     default_params = SamplingParams()
    #     try:
    #         sampling_params = SamplingParams(
    #             temperature=params_dict.get('temperature', default_params.temperature),
    #             top_p=params_dict.get('top_p', default_params.top_p),
    #             top_k=params_dict.get('top_k', default_params.top_k),
    #             use_beam_search=True if params_dict['num_beams'] > 1 else False,
    #             length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
    #             early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
    #             stop=params_dict.get('stopping_strings', default_params.stop),
    #             ignore_eos=params_dict.get('ban_eos_token', False),
    #             max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
    #         )
    #     except ValueError as e:
    #         print(e)
    #         return False, e
    #     return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         model_path = Path(r_json['data'][0]['root']).name
    #         r_json['data'][0]['root'] = model_path
    #         return r_json, None
    #     except Exception as e:
    #         return False, e

    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
        default_params = SamplingParams()
        try:
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', default_params.temperature),
                top_p=parameters.get('top_p', default_params.top_p),
                top_k=parameters.get('top_k', default_params.top_k),
                # num_beams may be absent from the request, so default to 1 (no beam search)
                use_beam_search=parameters.get('num_beams', 1) > 1,
                stop=parameters.get('stopping_strings', default_params.stop),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
            )
        except ValueError as e:
            print(e)
            return None, e
        return vars(sampling_params), None
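
    # Rough example of the mapping performed by get_parameters() (request values are hypothetical):
    #   {'temperature': 0.7, 'top_p': 0.9, 'num_beams': 1, 'max_new_tokens': 200}
    #   -> {'temperature': 0.7, 'top_p': 0.9, 'use_beam_search': False, 'max_tokens': 200, ...}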

    # def transform_sampling_params(params: SamplingParams):
    #     return {
    #         'temperature': params['temperature'],
    #         'top_p': params['top_p'],
    #         'top_k': params['top_k'],
    #         'use_beam_search': params['num_beams'] > 1,
    #         'length_penalty': params.get('length_penalty', default_params.length_penalty),
    #         'early_stopping': params.get('early_stopping', default_params.early_stopping),
    #         'stop': params.get('stopping_strings', default_params.stop),
    #         'ignore_eos': params.get('ban_eos_token', False),
    #         'max_tokens': params.get('max_new_tokens', default_params.max_tokens)
    #     }