58 lines
2.4 KiB
Python
58 lines
2.4 KiB
Python
import threading
|
|
from typing import Tuple
|
|
|
|
from flask import jsonify
|
|
from vllm import SamplingParams
|
|
|
|
from llm_server.database.database import log_prompt
|
|
from llm_server.llm.llm_backend import LLMBackend
|
|
|
|
|
|
class VLLMBackend(LLMBackend):
    """LLM backend adapter for a vLLM inference server.

    Converts the server-wide request parameters into vLLM ``SamplingParams``
    and turns vLLM's JSON reply into the ``{'results': [{'text': ...}]}``
    body returned to clients, logging each prompt/response pair via
    ``log_prompt``.
    """

    # Attribute dict of a default-constructed SamplingParams; used as the
    # fallback for any sampling parameter the client did not send.
    _default_params = vars(SamplingParams())

    @staticmethod
    def _extract_response_text(generated: str, prompt: str) -> str:
        """Return the completion portion of vLLM's generated text.

        The generated text may echo the prompt ahead of the completion
        (NOTE(review): confirm against the vLLM server version in use). If
        ``prompt`` occurs in ``generated``, everything after its first
        occurrence is kept. Otherwise, or when ``prompt`` is empty, the whole
        text is treated as the completion. Leading/trailing spaces and
        newlines are then removed in any order.
        """
        if prompt:
            # partition() never raises for a missing separator and keeps the
            # full remainder, unlike split(prompt)[1]. split() would raise
            # IndexError when absent and truncate at a repeated prompt.
            _, found, completion = generated.partition(prompt)
            if found:
                generated = completion
        return generated.strip(' \n')

    def handle_response(self, success, request, response_json_body, response_status_code,
                        client_ip, token, prompt: str, elapsed_time, parameters, headers):
        """Build the client response from a vLLM reply and log the exchange.

        :param success: Not read by this implementation.
        :param request: The incoming request; only ``request.url`` is read.
        :param response_json_body: Decoded vLLM reply. ``'text'`` is read as a
            list whose first element is the generated text. ``'details'`` ->
            ``'generated_tokens'`` is forwarded to the log.
        :param response_status_code: Backend status code, forwarded to the log.
        :param client_ip: Forwarded to the log.
        :param token: Forwarded to the log.
        :param prompt: Prompt text sent to vLLM; removed from the generated
            text if it is echoed back.
        :param elapsed_time: Generation time, forwarded to the log.
        :param parameters: Request parameters, forwarded to the log.
        :param headers: Request headers, forwarded to the log.
        :return: ``(jsonify(...), 200)`` with body
            ``{'results': [{'text': <completion>}]}``. NOTE(review): the
            status is always 200, regardless of ``response_status_code``.
        """
        # `or []` also covers an explicit null 'text' field.
        texts = response_json_body.get('text') or []
        if texts:
            backend_response = self._extract_response_text(texts[0], prompt)
        else:
            # Failsafe: the reply carried no generated text.
            backend_response = ''

        # Read the URL here, before the worker thread runs. Presumably this is
        # because a Flask request proxy is not usable outside its request
        # context/thread.
        r_url = request.url

        def background_task():
            log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response,
                       gen_time=elapsed_time, parameters=parameters, headers=headers,
                       backend_response_code=response_status_code, request_url=r_url,
                       response_tokens=(response_json_body.get('details') or {}).get('generated_tokens'))

        # TODO: use async/await instead of threads
        # join() makes this effectively synchronous. Running log_prompt in a
        # thread means an exception it raises goes to threading's excepthook
        # instead of failing this response, so logging stays best-effort.
        thread = threading.Thread(target=background_task)
        thread.start()
        thread.join()

        return jsonify({'results': [{'text': backend_response}]}), 200

    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
        """Validate client sampling parameters and convert them for vLLM.

        :param parameters: Mapping using the server's parameter names
            (``temperature``, ``top_p``, ``top_k``, ``num_beams``,
            ``stopping_strings``, ``ban_eos_token``, ``max_new_tokens``).
            Missing keys fall back to ``SamplingParams`` defaults;
            ``ban_eos_token`` falls back to ``False``.
        :return: ``(params, None)`` on success. ``params`` is a copy of the
            resulting ``SamplingParams`` attribute dict, with ``max_tokens``
            renamed to ``max_new_tokens``. On rejected values, returns
            ``(None, message)`` instead.
        """
        try:
            # top_k == -1 means disabled; map any non-positive value to it.
            top_k = parameters.get('top_k', self._default_params['top_k'])
            if top_k <= 0:
                top_k = -1
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self._default_params['temperature']),
                top_p=parameters.get('top_p', self._default_params['top_p']),
                top_k=top_k,
                use_beam_search=parameters.get('num_beams', 0) > 1,
                stop=parameters.get('stopping_strings', self._default_params['stop']),
                ignore_eos=parameters.get('ban_eos_token', False),
                max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']),
            )
        except (ValueError, TypeError) as e:
            # ValueError: presumably SamplingParams' own argument validation.
            # TypeError: wrongly-typed client values (e.g. a null or string
            # top_k/num_beams) that would otherwise escape as an unhandled
            # exception instead of a validation message.
            return None, str(e).strip('.')

        # We use max_new_tokens throughout the server. Copy first so the rename
        # does not mutate the SamplingParams instance's own attribute dict.
        result = dict(vars(sampling_params))
        result['max_new_tokens'] = result.pop('max_tokens')
        return result, None
|