"""
|
|
|
|
This file is used by the worker that processes requests.
|
|
|
|
"""
|
|
|
|
import json
|
|
|
|
import time
|
|
|
|
from uuid import uuid4
|
|
|
|
|
|
|
|
import requests
|
|
|
|
|
|
|
|
from llm_server import opts
|
|
|
|
from llm_server.database import tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
# TODO: make the VLMM backend return TPS and time elapsed
|
|
|
|
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
|
|
|
|
|
|
|
|
def prepare_json(json_data: dict) -> dict:
    """Return the request payload to be forwarded to the backend.

    Currently a passthrough. ``logit_bias`` is not supported by the
    backend; stripping it is intentionally left disabled for now:

    # del json_data['logit_bias']
    """
    return json_data
|
|
|
|
|
|
|
|
|
|
|
|
def transform_to_text(json_request, api_response):
    """Collapse a streamed (SSE) chat response into a single non-streamed reply.

    This is to convert a streaming request to a non-streamed request. Don't
    think this is necessary.

    :param json_request: the original request dict; its ``messages`` list is
        flattened to a prompt string for token counting.
    :param api_response: the raw SSE body, i.e. newline-separated
        ``data: {...}`` lines.
    :return: an OpenAI-style ``chat.completion`` dict.
    """
    prompt = transform_prompt_to_text(json_request['messages'])
    text = ''
    finish_reason = None
    for line in api_response.split('\n'):
        if not line.startswith('data:'):
            continue
        try:
            data = json.loads(line[5:].strip())
        except json.decoder.JSONDecodeError:
            # Non-JSON payload (e.g. the "data: [DONE]" sentinel) ends the stream.
            break
        choices = data.get('choices') or []
        for choice in choices:
            if 'delta' in choice and 'content' in choice['delta']:
                text += choice['delta']['content']
        # Use .get(): stream chunks are not guaranteed to carry finish_reason,
        # and direct indexing raised KeyError/IndexError on such chunks.
        if choices and choices[0].get('finish_reason'):
            finish_reason = choices[0]['finish_reason']

    prompt_tokens = len(tokenizer.encode(prompt))
    completion_tokens = len(tokenizer.encode(text))

    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
    return {
        "id": str(uuid4()),
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens
        },
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": text
                },
                "finish_reason": finish_reason,
                "index": 0
            }
        ]
    }
|
|
|
|
|
|
|
|
|
|
|
|
def transform_prompt_to_text(prompt: list) -> str:
    """Flatten a list of chat messages into a single newline-joined string.

    :param prompt: list of message dicts, each with a ``content`` key.
    :return: the message contents joined by newlines, with any leading and
        trailing newlines stripped.
    """
    # join() avoids the quadratic string concatenation of a += loop; the
    # final strip('\n') preserves the original trailing-newline behavior.
    return '\n'.join(item['content'] for item in prompt).strip('\n')
|
|
|
|
|
|
|
|
|
|
|
|
def handle_blocking_request(json_data: dict):
    """Forward a non-streaming request to the backend and wait for the result.

    :param json_data: the client request payload (passed through
        :func:`prepare_json` before sending).
    :return: a ``(success, response, error)`` tuple:
        ``(True, requests.Response, None)`` on success,
        ``(False, None, error_message)`` when the HTTP call itself failed.
    """
    try:
        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=120)
    except Exception as e:
        # Network-level failure (timeout, connection refused, SSL error, ...).
        return False, None, f'{e.__class__.__name__}: {e}'
    # TODO: check r.status_code / body for a backend-reported error here?
    return True, r, None
|
|
|
|
|
|
|
|
|
|
def generate(json_data: dict):
    """Dispatch an incoming request to the appropriate handler.

    Streaming requests are rejected (not implemented); everything else is
    forwarded to :func:`handle_blocking_request`.
    """
    if json_data.get('stream'):
        raise Exception('streaming not implemented')
    return handle_blocking_request(json_data)
|