110 lines
5.4 KiB
Python
110 lines
5.4 KiB
Python
from typing import Tuple
|
|
|
|
from flask import jsonify
|
|
from vllm import SamplingParams
|
|
|
|
from llm_server.database import log_prompt
|
|
from llm_server.llm.llm_backend import LLMBackend
|
|
from llm_server.routes.helpers.client import format_sillytavern_err
|
|
from llm_server.routes.helpers.http import validate_json
|
|
|
|
|
|
# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
|
|
|
|
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
|
|
|
|
class VLLMBackend(LLMBackend):
    """LLMBackend implementation that adapts raw vLLM HTTP responses to the
    proxy's uniform JSON reply shape and logs every prompt/response pair."""

    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        """Validate the backend response, extract the generated text, log it,
        and return a (flask.Response, 200) tuple.

        Always returns HTTP 200 to the client; backend failures are reported
        inside the JSON body so SillyTavern-style clients can display them.
        """
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False

        # `response` may be None (or a non-requests object) when the backend
        # call failed upstream; fall back to 0 rather than crash.
        # NOTE: original used a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit — narrowed to AttributeError.
        try:
            response_status_code = response.status_code
        except AttributeError:
            response_status_code = 0

        if response_valid_json:
            text_parts = response_json_body.get('text', [])
            if len(text_parts):
                # vLLM returns the prompt and the completion together, so the
                # prompt must be stripped off the front. partition() instead of
                # split(prompt)[1] avoids an IndexError when the backend does
                # not echo the prompt (the original crashed in that case).
                _, matched, generated = text_parts[0].partition(prompt)
                raw = generated if matched else text_parts[0]
                backend_response = raw.strip(' ').strip('\n')
            else:
                # Failsafe: backend produced no text entries at all.
                backend_response = ''

            # TODO: how to detect an error?
            # if backend_response == '':
            #     backend_err = True
            #     backend_response = format_sillytavern_err(
            #         f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
            #         f'HTTP CODE {response_status_code}'
            #     )

            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({'results': [{'text': backend_response}]}), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            # Reuse the safely-computed status code from above; the original
            # re-read response.status_code here, which could raise
            # AttributeError on a truthy non-response object.
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response_status_code or None, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    # def validate_params(self, params_dict: dict):
    #     default_params = SamplingParams()
    #     try:
    #         sampling_params = SamplingParams(
    #             temperature=params_dict.get('temperature', default_params.temperature),
    #             top_p=params_dict.get('top_p', default_params.top_p),
    #             top_k=params_dict.get('top_k', default_params.top_k),
    #             use_beam_search=True if params_dict['num_beams'] > 1 else False,
    #             length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
    #             early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
    #             stop=params_dict.get('stopping_strings', default_params.stop),
    #             ignore_eos=params_dict.get('ban_eos_token', False),
    #             max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
    #         )
    #     except ValueError as e:
    #         print(e)
    #         return False, e
    #     return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         model_path = Path(r_json['data'][0]['root']).name
    #         r_json['data'][0]['root'] = model_path
    #         return r_json, None
    #     except Exception as e:
    #         return False, e

    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
        """Translate client-supplied generation parameters into a vLLM
        SamplingParams dict.

        Returns (params_dict, None) on success, or (None, exc) when vLLM
        rejects a value. Missing keys fall back to vLLM's own defaults.
        """
        default_params = SamplingParams()
        try:
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', default_params.temperature),
                top_p=parameters.get('top_p', default_params.top_p),
                top_k=parameters.get('top_k', default_params.top_k),
                # Original indexed parameters['num_beams'] directly, raising an
                # uncaught KeyError (not ValueError) when the client omitted
                # it; default to 1 (no beam search).
                use_beam_search=parameters.get('num_beams', 1) > 1,
                stop=parameters.get('stopping_strings', default_params.stop),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
            )
        except ValueError as e:
            print(e)
            return None, e
        return vars(sampling_params), None
|
|
|
|
# def transform_sampling_params(params: SamplingParams):
|
|
# return {
|
|
# 'temperature': params['temperature'],
|
|
# 'top_p': params['top_p'],
|
|
# 'top_k': params['top_k'],
|
|
# 'use_beam_search' = True if parameters['num_beams'] > 1 else False,
|
|
# length_penalty = parameters.get('length_penalty', default_params.length_penalty),
|
|
# early_stopping = parameters.get('early_stopping', default_params.early_stopping),
|
|
# stop = parameters.get('stopping_strings', default_params.stop),
|
|
# ignore_eos = parameters.get('ban_eos_token', False),
|
|
# max_tokens = parameters.get('max_new_tokens', default_params.max_tokens)
|
|
# }
|