import json
import threading
import time
import traceback

from flask import Response, jsonify, request

from . import openai_bp
from ..cache import redis
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
from ...llm.vllm import tokenize


# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
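    """OpenAI-compatible chat completions endpoint; streams via SSE when the client requests it."""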
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
    else:
        handler = OpenAIRequestHandler(request, request_json_body)
        if request_json_body.get('stream'):
            if not opts.enable_streaming:
                # TODO: return a proper OAI error message
                return 'disabled', 401

            if opts.mode != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            response_status_code = 0
            start_time = time.time()
            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                # TODO: simulate OAI here
                raise Exception('TODO: simulate OAI here')
            else:
                # Flatten the OpenAI-style message list into a single prompt string for the backend.
                handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
                msg_to_backend = {
                    **handler.parameters,
                    'prompt': handler.prompt,
                    'stream': True,
                }
                try:
                    response = generator(msg_to_backend)
                    r_headers = dict(request.headers)
                    r_url = request.url
                    # Either report the model we are actually running or echo back whatever the client requested.
                    model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
                    oai_string = generate_oai_string(30)

                    def generate():
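                        # The vLLM backend streams NUL-delimited JSON objects whose 'text' field
                        # holds the cumulative output (prompt + generation so far), so each object
                        # is diffed against what has already been sent to produce the SSE delta.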
                        generated_text = ''
                        partial_response = b''
                        for chunk in response.iter_content(chunk_size=1):
                            partial_response += chunk
                            if partial_response.endswith(b'\x00'):
                                json_strs = partial_response.split(b'\x00')
                                # Clear the buffer so chunks that were already parsed aren't processed again.
                                partial_response = b''
                                for json_str in json_strs:
                                    if json_str:
                                        try:
                                            json_obj = json.loads(json_str.decode())
                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                            generated_text = generated_text + new
                                        except IndexError:
                                            # The chunk didn't contain the prompt plus the text generated so far; skip it.
                                            continue

                                        data = {
                                            "id": f"chatcmpl-{oai_string}",
                                            "object": "chat.completion.chunk",
                                            "created": int(time.time()),
                                            "model": model,
                                            "choices": [
                                                {
                                                    "index": 0,
                                                    "delta": {
                                                        "content": new
                                                    },
                                                    "finish_reason": None
                                                }
                                            ]
                                        }
                                        yield f'data: {json.dumps(data)}\n\n'

                        yield 'data: [DONE]\n\n'
                        end_time = time.time()
                        elapsed_time = end_time - start_time
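
                        # Record the finished generation in the database once streaming is complete.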
                        def background_task():
                            generated_tokens = tokenize(generated_text)
                            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

                        # TODO: use async/await instead of threads
                        thread = threading.Thread(target=background_task)
                        thread.start()
                        thread.join()

                    return Response(generate(), mimetype='text/event-stream')
                except Exception:
                    # TODO: simulate OAI here
                    raise Exception
        else:
            try:
                return handler.handle_request()
            except Exception:
                traceback.print_exc()
                return 'Internal server error', 500