local-llm-server/llm_server/llm/vllm/generate.py


"""
This file is used by the worker that processes requests.
"""
import json
import time
import traceback
from uuid import uuid4
import requests
import llm_server
from llm_server import opts

# TODO: make the vLLM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py


def prepare_json(json_data: dict):
    # logit_bias is not currently supported
    # del json_data['logit_bias']
    # Rename max_new_tokens to the field name vLLM expects.
    json_data['max_tokens'] = json_data.pop('max_new_tokens')
    return json_data
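
# For instance (illustrative values): prepare_json({'prompt': 'Hi', 'max_new_tokens': 16})
# mutates the dict in place and returns {'prompt': 'Hi', 'max_tokens': 16}.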


def transform_to_text(json_request, api_response):
    """
    Convert a streaming response into a single, non-streamed chat completion
    response. Don't think this is necessary anymore.
    :param json_request: the original request body (OpenAI chat format).
    :param api_response: the raw SSE response body, as text.
    :return: a dict in the OpenAI chat completion format.
    """
    prompt = transform_prompt_to_text(json_request['messages'])
    text = ''
    finish_reason = None
    for line in api_response.split('\n'):
        if line.startswith('data:'):
            try:
                data = json.loads(line[5:].strip())
            except json.decoder.JSONDecodeError:
                break
            if 'choices' in data:
                for choice in data['choices']:
                    if 'delta' in choice and 'content' in choice['delta']:
                        text += choice['delta']['content']
                if data['choices'][0]['finish_reason']:
                    finish_reason = data['choices'][0]['finish_reason']
    prompt_tokens = llm_server.llm.get_token_count(prompt)
    completion_tokens = llm_server.llm.get_token_count(text)
    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
        "id": str(uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        },
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": text
                },
                "finish_reason": finish_reason,
                "index": 0
            }
        ]
    }


def transform_prompt_to_text(prompt: list):
    # Flatten an OpenAI-style messages list into plain text, one message per line.
    text = ''
    for item in prompt:
        text += item['content'] + '\n'
    return text.strip('\n')
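
# For example (illustrative): transform_prompt_to_text([{'role': 'user', 'content': 'Hi'}])
# returns 'Hi'; the roles themselves are dropped.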


def handle_blocking_request(json_data: dict):
    # Returns a (success, response, error message) tuple so the caller can tell
    # transport failures apart from backend errors.
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception:
        traceback.print_exc()
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None


def generate(json_data: dict):
    if json_data.get('stream'):
        # Streaming requests return the raw requests.Response so the caller can
        # iterate over the SSE stream itself.
        return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    else:
        return handle_blocking_request(json_data)
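

# A minimal usage sketch, not part of the original file: it assumes opts has been
# configured elsewhere and that the backend is a vLLM server exposing /generate.
# The request body shape is illustrative.
if __name__ == '__main__':
    example_request = {
        'messages': [{'role': 'user', 'content': 'Say hello.'}],
        'max_new_tokens': 32,
        'stream': False,
    }
    success, response, error = generate(dict(example_request))  # copy, since prepare_json mutates
    if success:
        print(response.json())
    else:
        print(f'Generation failed: {error}')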