local-llm-server/llm_server/llm/vllm/generate.py

"""
This file is used by the worker that processes requests.
"""
import json
import time
import traceback
from uuid import uuid4
import requests
import llm_server
from llm_server import opts
from llm_server.routes.cache import redis

# TODO: make the vLLM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py


def prepare_json(json_data: dict):
    # logit_bias is not currently supported
    # del json_data['logit_bias']

    # Convert back to the parameter name vLLM expects.
    json_data['max_tokens'] = json_data.pop('max_new_tokens')
    return json_data
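
# A quick illustrative example (hypothetical input): prepare_json() only renames the
# token-limit key and passes every other field through untouched, e.g.
#   prepare_json({'prompt': 'Hi', 'max_new_tokens': 16})
#   -> {'prompt': 'Hi', 'max_tokens': 16}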


def transform_to_text(json_request, api_response):
    """
    Convert a streamed API response back into a non-streamed (blocking) response. Don't think this is necessary.
    :param json_request: the original request body; its 'messages' list is used to rebuild the prompt.
    :param api_response: the raw SSE response body returned by the backend.
    :return: a dict shaped like an OpenAI chat.completion response.
    """
    prompt = transform_prompt_to_text(json_request['messages'])
    text = ''
    finish_reason = None
    for line in api_response.split('\n'):
        if line.startswith('data:'):
            try:
                data = json.loads(line[5:].strip())
            except json.decoder.JSONDecodeError:
                break
            if 'choices' in data:
                for choice in data['choices']:
                    if 'delta' in choice and 'content' in choice['delta']:
                        text += choice['delta']['content']
                if data['choices'][0]['finish_reason']:
                    finish_reason = data['choices'][0]['finish_reason']

    prompt_tokens = len(llm_server.llm.get_token_count(prompt))
    completion_tokens = len(llm_server.llm.get_token_count(text))
    running_model = redis.get('running_model', str, 'ERROR')

    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
        "id": str(uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": running_model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        },
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": text
                },
                "finish_reason": finish_reason,
                "index": 0
            }
        ]
    }
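
# A minimal sketch of the SSE payload transform_to_text() expects from the backend
# (hypothetical example lines in the OpenAI-style streaming format, not captured output):
#
#   data: {"choices": [{"delta": {"content": "Hello"}, "finish_reason": null}]}
#   data: {"choices": [{"delta": {"content": " world"}, "finish_reason": "stop"}]}
#   data: [DONE]
#
# The final 'data: [DONE]' sentinel is not valid JSON, so the JSONDecodeError branch
# above is what breaks out of the parsing loop.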


def transform_prompt_to_text(prompt: list):
    text = ''
    for item in prompt:
        text += item['content'] + '\n'
    return text.strip('\n')


def handle_blocking_request(json_data: dict):
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    except requests.exceptions.ReadTimeout:
        return False, None, 'Request to backend timed out'
    except Exception:
        traceback.print_exc()
        return False, None, 'Request to backend encountered error'
    if r.status_code != 200:
        return False, r, f'Backend returned {r.status_code}'
    return True, r, None


def generate(json_data: dict):
    if json_data.get('stream'):
        return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
    else:
        return handle_blocking_request(json_data)
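

# A rough usage sketch (hypothetical call; assumes opts.backend_url, opts.verify_ssl and
# opts.backend_generate_request_timeout were configured at server startup):
#
#   success, response, error_msg = generate({'prompt': 'Once upon a time', 'max_new_tokens': 32, 'stream': False})
#   if success:
#       print(response.json())
#   else:
#       print(f'Backend error: {error_msg}')
#
# Note the asymmetric return types: with 'stream': True, generate() returns the raw
# requests.Response for the caller to iterate over instead of a (success, response, error) tuple.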