157 lines
7.1 KiB
Python
157 lines
7.1 KiB
Python
import json
|
|
import time
|
|
import traceback
|
|
|
|
from flask import Response, jsonify, request
|
|
|
|
from llm_server.custom_redis import redis
|
|
from . import openai_bp
|
|
from ..helpers.http import validate_json
|
|
from ..openai_request_handler import OpenAIRequestHandler
|
|
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
|
from ... import opts
|
|
from ...database.log_to_db import log_to_db
|
|
from ...llm.generator import generator
|
|
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
|
|
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
|
|
|
|
|
# TODO: add rate-limit headers?
|
|
|
|
|
|
def _iter_text_deltas(byte_chunks, prompt):
    """Yield the newly generated text carried by a NUL-delimited JSON byte stream.

    The backend stream is parsed as JSON objects separated by NUL bytes. Each
    object's ``text[0]`` is treated as cumulative output: ``prompt`` followed by
    everything generated so far. Only the part after that prefix is yielded.

    Args:
        byte_chunks: Iterable of ``bytes`` pieces of any size (an object may be
            split across several pieces).
        prompt: The prompt string that was sent to the backend.

    Yields:
        str: The new text contributed by each complete JSON object. Objects with
        an empty ``text`` list, or whose text does not contain the expected
        prefix, are skipped.

    Raises:
        json.JSONDecodeError, UnicodeDecodeError, KeyError: if a complete object
        is malformed (same as the original inline parser).
    """
    generated_text = ''
    buffer = b''
    for chunk in byte_chunks:
        buffer += chunk
        if b'\x00' not in chunk:
            continue
        # Everything before the last NUL is a complete object; the tail (usually
        # b'') is an incomplete object carried over to the next chunk. Dropping
        # processed objects from the buffer prevents them from being re-parsed
        # and re-emitted as empty deltas.
        *complete, buffer = buffer.split(b'\x00')
        for json_str in complete:
            if not json_str:
                continue
            json_obj = json.loads(json_str.decode())
            try:
                full_text = json_obj['text'][0]
            except IndexError:
                continue
            # partition() keeps everything after the FIRST occurrence of the
            # prefix, unlike split(...)[1] which truncates at a second occurrence.
            _, found, new = full_text.partition(prompt + generated_text)
            if not found:
                continue
            generated_text += new
            yield new


@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    """Handle an OpenAI-compatible ``POST /chat/completions`` request.

    Non-streaming requests are checked with ``validate_oai`` and delegated to
    ``OpenAIRequestHandler.handle_request``. Streaming requests are converted
    into a backend prompt, queued on ``priority_queue``, and answered as a
    ``text/event-stream`` of ``chat.completion.chunk`` objects terminated by
    ``data: [DONE]``.

    Returns:
        A Flask response (or ``(body, status)`` tuple). Returns 400 for missing
        ``messages``/``model``, 403 when streaming is requested but disabled,
        and 500 on unexpected errors.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

    handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)

    if not request_json_body.get('stream'):
        try:
            invalid_oai_err_msg = validate_oai(request_json_body)
            if invalid_oai_err_msg:
                return invalid_oai_err_msg
            return handler.handle_request()
        except Exception:
            traceback.print_exc()
            return 'Internal server error', 500

    if not opts.enable_streaming:
        # Flask rejects a view that returns None, so refuse explicitly.
        return jsonify({'code': 403, 'msg': 'streaming is disabled'}), 403

    handler.parameters, _ = handler.get_parameters()
    handler.request_json_body = {
        'messages': handler.request_json_body['messages'],
        'model': handler.request_json_body['model'],
        **handler.parameters,
    }

    invalid_oai_err_msg = validate_oai(handler.request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg

    handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])

    # The prompt is built from the original request's messages, since
    # request_json_body has just been rewritten by oai_to_vllm.
    if opts.openai_silent_trim:
        handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
    else:
        handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])

    response_status_code = 0
    start_time = time.time()

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response

    msg_to_backend = {
        **handler.parameters,
        'prompt': handler.prompt,
        'stream': True,
    }

    # Add a dummy event to the queue and wait for it to reach a worker
    event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
    if not event:
        log_to_db(
            handler.client_ip,
            handler.token,
            handler.prompt,
            None,
            None,
            handler.parameters,
            request.headers,
            response_status_code,
            request.url,
            handler.backend_url,
        )
        return handler.handle_ratelimited()

    # Wait for a worker to get our request and discard it.
    _, _, _ = event.wait()

    # The worker incremented these counters; we decrement them exactly once.
    # On success that must happen only after the stream has been sent (or the
    # client disconnected), not when this view returns: Flask iterates
    # generate() after the view has already returned.
    released = False

    def release_worker():
        nonlocal released
        if released:
            return
        released = True
        decrement_ip_count(handler.client_ip, 'processing_ips')
        decr_active_workers(handler.selected_model, handler.backend_url)

    handed_off = False
    try:
        # Capture request data now; the request context may be gone while streaming.
        r_headers = dict(request.headers)
        r_url = request.url
        model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
        oai_string = generate_oai_string(30)

        def generate():
            response = generator(msg_to_backend, handler.backend_url)
            generated_text = ''
            for new in _iter_text_deltas(response.iter_content(chunk_size=1), handler.prompt):
                generated_text += new
                data = {
                    "id": f"chatcmpl-{oai_string}",
                    "object": "chat.completion.chunk",
                    "created": int(time.time()),
                    "model": model,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "content": new
                            },
                            "finish_reason": None
                        }
                    ]
                }
                yield f'data: {json.dumps(data)}\n\n'
            yield 'data: [DONE]\n\n'

            elapsed_time = time.time() - start_time
            log_to_db(
                handler.client_ip,
                handler.token,
                handler.prompt,
                generated_text,
                elapsed_time,
                handler.parameters,
                r_headers,
                response_status_code,
                r_url,
                handler.backend_url,
            )

        stream_response = Response(generate(), mimetype='text/event-stream')
        # Runs when the WSGI server closes the response iterable, i.e. after the
        # body has been fully sent or the transfer was aborted.
        stream_response.call_on_close(release_worker)
        handed_off = True
        return stream_response
    except Exception:
        traceback.print_exc()
        return 'INTERNAL SERVER', 500
    finally:
        # If the response was never handed off, nothing else will release
        # the counters the worker incremented.
        if not handed_off:
            release_worker()
|