local-llm-server/llm_server/llm/vllm/vllm_backend.py


import traceback
from typing import Tuple, Union
import requests
from flask import jsonify
from vllm import SamplingParams
import llm_server
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.llm.llm_backend import LLMBackend


class VLLMBackend(LLMBackend):
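    # Capture vLLM's default sampling values so any parameter the client omits falls back to them.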
    _default_params = vars(SamplingParams())
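
    # Normalize the raw response from the vLLM backend into this server's standard JSON shape and log the prompt.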
    def handle_response(self, success, request, response_json_body, response_status_code, client_ip, token, prompt: str, elapsed_time, parameters, headers):
        if len(response_json_body.get('text', [])):
            # The backend returns the prompt and the generated text concatenated,
            # so strip the prompt off the front to keep only the completion.
            backend_response = response_json_body['text'][0].split(prompt, 1)[-1].strip()
        else:
            # Failsafe
            backend_response = ''
        log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
                   response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
        return jsonify({'results': [{'text': backend_response}]}), 200
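
    # Translate client-supplied generation parameters into vLLM SamplingParams.
    # Returns (params_dict, None) on success, or (None, error_message) if vLLM rejects a value.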
    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        try:
            # top_k == -1 means disabled
            top_k = parameters.get('top_k', self._default_params['top_k'])
            if top_k <= 0:
                top_k = -1
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self._default_params['temperature']),
                top_p=parameters.get('top_p', self._default_params['top_p']),
                top_k=top_k,
                use_beam_search=parameters.get('num_beams', 0) > 1,
                stop=parameters.get('stopping_strings', self._default_params['stop']),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
            )
        except ValueError as e:
            return None, str(e).strip('.')
        return vars(sampling_params), None
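
    # Reject requests that ask for more new tokens than the server-wide limit allows.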
    def validate_request(self, parameters) -> Tuple[bool, Union[str, None]]:
        if parameters.get('max_new_tokens', 0) > opts.max_new_tokens:
            return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None
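
    # Tokenizing via the backend's /tokenize endpoint is left disabled; validate_prompt() below
    # relies on the shared llm_server.llm.tokenizer instead.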
    # def tokenize(self, prompt):
    #     try:
    #         r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    #         j = r.json()
    #         return j['length']
    #     except:
    #         # Fall back to whatever the superclass is doing.
    #         print(traceback.format_exc())
    #         return super().tokenize(prompt)
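
    # Check the prompt's token count against the configured context window.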
    def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
        prompt_len = llm_server.llm.tokenizer(prompt)
        if prompt_len > opts.context_size:
            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}). Please lower your context size'
        return True, None