local-llm-server/llm_server/llm/vllm/vllm_backend.py

from flask import jsonify
from vllm import SamplingParams

from llm_server.database import log_prompt
from llm_server.helpers import indefinite_article
from llm_server.llm.llm_backend import LLMBackend
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
# https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py
# TODO: https://gitgud.io/khanon/oai-reverse-proxy/-/blob/main/src/proxy/middleware/common.ts?ref_type=heads#L69
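# Example of the kind of dict validate_params() receives (a sketch; the accepted
# fields depend on the installed vLLM version's SamplingParams):
#   {'temperature': 0.7, 'top_p': 0.9, 'max_tokens': 256, 'stop': ['\n']}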


class VLLMBackend(LLMBackend):
    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
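        """Turn a raw vLLM backend response into a Flask response.

        Validates the JSON body, converts backend errors into a SillyTavern-formatted
        message, logs the prompt/response pair, and always returns HTTP 200 so the
        client can render the message.
        """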
        response_valid_json, response_json_body = validate_json(response)
        backend_err = False
        try:
            response_status_code = response.status_code
        except AttributeError:
            # No HTTP response object (e.g. the request to the backend failed outright).
            response_status_code = 0
        if response_valid_json:
            backend_response = response_json_body
            if response_json_body.get('error'):
                backend_err = True
                error_type = response_json_body.get('error_type')
                error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
                backend_response = format_sillytavern_err(
                    f'Backend (vllm) {error_type_string}: {response_json_body.get("error")}',
                    f'HTTP CODE {response_status_code}'
                )
            # On a backend error, `backend_response` is the formatted error string, so log
            # it directly; otherwise log the generated text from the response body.
            log_prompt(
                client_ip, token, prompt,
                backend_response if backend_err else backend_response['choices'][0]['message']['content'],
                elapsed_time if not backend_err else None,
                parameters, headers, response_status_code,
                response_json_body.get('details', {}).get('generated_tokens'),
                is_error=backend_err,
            )
            return jsonify(backend_response), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            # `response_status_code` is 0 when no response was received; log None instead.
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response_status_code or None, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    def validate_params(self, params_dict: dict):
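        """Check whether `params_dict` constructs a valid vLLM SamplingParams.

        Returns (True, None) on success, or (False, exception) on failure.
        """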
        try:
            # SamplingParams does its own validation and raises ValueError on bad values.
            SamplingParams(**params_dict)
        except ValueError as e:
            print(e)
            return False, e
        return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/models', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         model_path = Path(r_json['data'][0]['root']).name
    #         r_json['data'][0]['root'] = model_path
    #         return r_json, None
    #     except Exception as e:
    #         return False, e
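
# Minimal usage sketch (hypothetical caller; the real route handlers live under
# llm_server/routes and may pass different arguments):
#   backend = VLLMBackend()
#   ok, err = backend.validate_params(parameters)
#   if ok:
#       flask_response, status_code = backend.handle_response(
#           True, response, None, client_ip, token, prompt, elapsed, parameters, headers)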