local-llm-server/llm_server/routes/openai/chat_completions.py

110 lines
5.1 KiB
Python
Raw Normal View History

import json
import threading
import time
2023-09-14 14:36:22 -06:00
import traceback
from flask import Response, jsonify, request
2023-09-12 16:40:09 -06:00
from . import openai_bp
2023-09-14 14:36:22 -06:00
from ..helpers.client import format_sillytavern_err
2023-09-12 16:40:09 -06:00
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
2023-09-12 16:40:09 -06:00
# TODO: add rate-limit headers?
2023-09-12 16:40:09 -06:00
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    """Handle ``POST /chat/completions`` in an OpenAI-style format.

    The JSON body must be valid and contain non-empty ``messages`` and
    ``model`` entries; otherwise a 400 JSON error is returned.

    Non-streaming requests (``stream`` absent or falsy) are delegated to
    ``OpenAIRequestHandler.handle_request()``. Any exception raised there is
    printed and converted into a 500 response.

    Streaming requests (``stream`` truthy) are forwarded to the backend via
    ``generator()`` and relayed as ``text/event-stream`` chunks shaped like
    ``chat.completion.chunk`` objects, followed by ``data: [DONE]``. Once the
    stream is exhausted, the prompt and generated text are logged from a
    background thread.

    Returns:
        A Flask response (or ``(body, status)`` tuple).

    Raises:
        NotImplementedError: streaming was requested but ``opts.mode`` is not
            ``'vllm'``.
        Exception: the handler rejected the streaming request, or setting up
            the backend stream failed.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

    handler = OpenAIRequestHandler(request, request_json_body)

    # ---- Non-streaming path -------------------------------------------------
    if not request_json_body.get('stream'):
        try:
            return handler.handle_request()
        except Exception as e:
            print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
            traceback.print_exc()
            print(request.data)
            return build_openai_response('', format_sillytavern_err('Server encountered exception.', 'error')), 500

    # ---- Streaming path -----------------------------------------------------
    if not opts.enable_streaming:
        # TODO: return a proper OAI error message
        return 'disabled', 401

    if opts.mode != 'vllm':
        # TODO: implement other backends
        raise NotImplementedError

    # NOTE(review): this value is never updated before being logged, so the
    # log always records status 0 for streamed requests -- confirm intended.
    response_status_code = 0
    start_time = time.time()

    # The second return value (the handler's "invalid" response) is not used
    # yet; see the TODO below.
    request_valid, _ = handler.validate_request()
    if not request_valid:
        # TODO: simulate OAI here
        raise Exception

    handler.prompt = handler.transform_messages_to_prompt()
    msg_to_backend = {
        **handler.parameters,
        'prompt': handler.prompt,
        'stream': True,
    }

    try:
        response = generator(msg_to_backend)
        # Snapshot request details now; they are consumed later by the logging
        # closure, which runs while the response body is being streamed
        # (presumably outside the request context -- TODO confirm).
        r_headers = dict(request.headers)
        r_url = request.url
        model = opts.running_model if opts.openai_epose_our_model else request_json_body.get('model')
    except Exception as exc:
        # Only ordinary errors are converted; KeyboardInterrupt/SystemExit
        # propagate untouched.
        # TODO: simulate OAI here
        raise Exception from exc

    def generate():
        """Yield SSE ``data:`` lines for each new piece of generated text."""
        generated_text = ''
        partial_response = b''
        for chunk in response.iter_content(chunk_size=1):
            partial_response += chunk
            # Backend messages are NUL-terminated JSON documents; wait until
            # the buffer ends on a terminator before parsing.
            if not partial_response.endswith(b'\x00'):
                continue

            json_strs = partial_response.split(b'\x00')
            # Everything buffered so far is complete. Clear the buffer so old
            # messages are not parsed again on the next terminator (which
            # would be quadratic and would emit spurious empty deltas).
            partial_response = b''

            for json_str in json_strs:
                if not json_str:
                    continue
                try:
                    text = json.loads(json_str.decode())['text'][0]
                except IndexError:
                    # Empty 'text' list: nothing to relay for this message.
                    continue

                # The code treats 'text' as prompt + everything generated so
                # far; whatever follows that known prefix is the new delta.
                # partition() keeps *all* text after the first match, even if
                # the prefix happens to recur later in the output.
                _, matched, new = text.partition(handler.prompt + generated_text)
                if not matched:
                    continue

                generated_text = generated_text + new
                data = {
                    "id": f"chatcmpl-{generate_oai_string(30)}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "content": new
                            },
                            "finish_reason": None
                        }
                    ]
                }
                yield f'data: {json.dumps(data)}\n\n'

        yield 'data: [DONE]\n\n'
        end_time = time.time()
        elapsed_time = end_time - start_time

        def background_task():
            """Tokenize the output and write the prompt log entry."""
            generated_tokens = tokenize(generated_text)
            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

        # TODO: use async/await instead of threads
        threading.Thread(target=background_task).start()

    return Response(generate(), mimetype='text/event-stream')