from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
from llm_server.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
class VLLMBackend(LLMBackend):
    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except AttributeError:
            # `response` may be None (or an exception object) if the request itself failed.
            response_status_code = 0

        if response_valid_json:
            if response_json_body.get('text'):
                # vllm's demo API server returns the prompt concatenated with the generated text,
                # so split the prompt back off. split(prompt, 1)[-1] stays safe if the prompt
                # isn't echoed back for some reason.
                backend_response = response_json_body['text'][0].split(prompt, 1)[-1].strip(' \n')
            else:
                # Failsafe
                backend_response = ''

            # TODO: how to detect an error?
            # if backend_response == '':
            #     backend_err = True
            #     backend_response = format_sillytavern_err(
            #         f'Backend (vllm-gptq) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
            #         f'HTTP CODE {response_status_code}'
            #     )

            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({'results': [{'text': backend_response}]}), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code if response else None, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200
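
    # A minimal sketch of the happy path handle_response() expects, assuming the response shape
    # of vllm's demo API server (the exact keys are an assumption, not confirmed by this module):
    #
    #   response_json_body = {
    #       'text': ['<prompt><generated text>'],
    #       'details': {'generated_tokens': 42},
    #   }
    #
    # The handler strips the prompt off the front and replies with
    #   {'results': [{'text': '<generated text>'}]}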
# def validate_params(self, params_dict: dict):
# default_params = SamplingParams()
# try:
# sampling_params = SamplingParams(
# temperature=params_dict.get('temperature', default_params.temperature),
# top_p=params_dict.get('top_p', default_params.top_p),
# top_k=params_dict.get('top_k', default_params.top_k),
# use_beam_search=True if params_dict['num_beams'] > 1 else False,
# length_penalty=params_dict.get('length_penalty', default_params.length_penalty),
# early_stopping=params_dict.get('early_stopping', default_params.early_stopping),
# stop=params_dict.get('stopping_strings', default_params.stop),
# ignore_eos=params_dict.get('ban_eos_token', False),
# max_tokens=params_dict.get('max_new_tokens', default_params.max_tokens)
# )
# except ValueError as e:
# print(e)
# return False, e
# return True, None
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
# backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
# r_json = backend_response.json()
# model_path = Path(r_json['data'][0]['root']).name
# r_json['data'][0]['root'] = model_path
# return r_json, None
# except Exception as e:
# return False, e
    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        default_params = SamplingParams()
        try:
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', default_params.temperature),
                top_p=parameters.get('top_p', default_params.top_p),
                top_k=parameters.get('top_k', default_params.top_k),
                # .get() with a default of 1 avoids a KeyError when 'num_beams' isn't supplied.
                use_beam_search=parameters.get('num_beams', 1) > 1,
                stop=parameters.get('stopping_strings', default_params.stop),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None
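
    # Example (hypothetical values): get_parameters({'temperature': 0.7, 'max_new_tokens': 100})
    # would return something like ({'temperature': 0.7, 'top_p': 1.0, ..., 'use_beam_search': False,
    # 'max_tokens': 100}, None), while an out-of-range value such as {'top_p': 5} fails
    # SamplingParams validation and returns (None, '<the validation error message>').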
    # def transform_sampling_params(params: SamplingParams):
    #     return {
    #         'temperature': params['temperature'],
    #         'top_p': params['top_p'],
    #         'top_k': params['top_k'],
    #         'use_beam_search': params['num_beams'] > 1,
    #         'length_penalty': params.get('length_penalty', default_params.length_penalty),
    #         'early_stopping': params.get('early_stopping', default_params.early_stopping),
    #         'stop': params.get('stopping_strings', default_params.stop),
    #         'ignore_eos': params.get('ban_eos_token', False),
    #         'max_tokens': params.get('max_new_tokens', default_params.max_tokens)
    #     }
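
# A hypothetical end-to-end sketch of how this backend might be driven (the request call,
# endpoint, and `elapsed` timing are assumptions, not part of this module):
#
#   backend = VLLMBackend()
#   sampling, err = backend.get_parameters({'temperature': 0.8, 'max_new_tokens': 200})
#   if err is None:
#       response = requests.post(f'{opts.backend_url}/generate', json={'prompt': prompt, **sampling}, verify=opts.verify_ssl)
#       return backend.handle_response(True, response, None, client_ip, token, prompt, elapsed, sampling, headers)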