local-llm-server/llm_server/routes/openai/chat_completions.py

import json
import time
import traceback

import ujson
from flask import Response, jsonify, request
from redis import Redis

from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit


# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
def openai_chat_completions(model_name=None):
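    """OpenAI-compatible chat completions endpoint.

    Non-streaming requests are handed off to OpenAIRequestHandler; streaming
    requests are queued and then relayed to the client from a Redis stream as
    server-sent events.
    """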
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON or missing "messages"/"model"'}), 400
    else:
        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
        if handler.offline:
            return return_invalid_model_err(model_name)
        if not request_json_body.get('stream'):
            try:
                return handler.handle_request()
            except Exception:
                traceback.print_exc()
                return 'Internal server error', 500
        else:
            if not opts.enable_streaming:
                return 'Streaming disabled', 403

            invalid_oai_err_msg = validate_oai(handler.request_json_body)
            if invalid_oai_err_msg:
                return invalid_oai_err_msg
            handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
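
            # Rebuild the request body with only the fields the inference worker
            # needs plus the sanitized sampling parameters. (The error value
            # returned alongside the parameters is not checked here.)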
            handler.parameters, e = handler.get_parameters()
            handler.request_json_body = {
                'messages': handler.request_json_body['messages'],
                'model': handler.request_json_body['model'],
                **handler.parameters
            }

            if opts.openai_silent_trim:
                handler.prompt = transform_messages_to_prompt(
                    trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
                )
            else:
                handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
            if not handler.prompt:
                # Prevent issues on the backend.
                return 'Invalid prompt', 400

            # Need to set the prompt in the JSON body since that's what the inference worker expects.
            handler.request_json_body['prompt'] = handler.prompt

            start_time = time.time()

            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                return invalid_response

            event = None
            if not handler.is_client_ratelimited():
                event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
            if not event:
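                # Either the client is rate limited or the queue rejected the
                # request; log the rejection and return a 429.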
                log_to_db(
                    handler.client_ip,
                    handler.token,
                    handler.prompt,
                    None,
                    None,
                    handler.parameters,
                    request.headers,
                    429,
                    request.url,
                    handler.backend_url,
                )
                return handler.handle_ratelimited()

            try:
                r_headers = dict(request.headers)
                r_url = request.url
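                # Report the model actually running on the backend, or just echo
                # back whatever model name the client asked for.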
                model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
                oai_string = generate_oai_string(30)

                # Need to do this before we enter generate() since we want to be able to
                # return a 408 if necessary.
                _, stream_name, error_msg = event.wait()
                if error_msg:
                    print('OAI failed to start streaming:', error_msg)
                    stream_name = None  # Set to None so the cleanup code ignores it.
                    return 'Request Timeout', 408

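                # Relay chunks from the worker's Redis stream to the client as
                # OpenAI-style server-sent events.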
                def generate():
                    stream_redis = Redis(db=8)
                    generated_text = ''
                    try:
                        last_id = '0-0'
                        while True:
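                            # Block until the worker appends new entries, or give
                            # up after the configured timeout (in milliseconds).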
                            stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                            if not stream_data:
                                print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                                yield 'data: [DONE]\n\n'
                                return
                            else:
                                for stream_index, item in stream_data[0][1]:
                                    last_id = stream_index
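                                    # Redis stream IDs are '<ms-timestamp>-<seq>';
                                    # reuse the timestamp for the chunk's 'created' field.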
                                    timestamp = int(stream_index.decode('utf-8').split('-')[0])
                                    data = ujson.loads(item[b'data'])
                                    if data['error']:
                                        # Not printing error since we can just check the daemon log.
                                        print('OAI streaming encountered error')
                                        yield 'data: [DONE]\n\n'
                                        return
                                    elif data['new']:
                                        response = {
                                            "id": f"chatcmpl-{oai_string}",
                                            "object": "chat.completion.chunk",
                                            "created": timestamp,
                                            "model": model,
                                            "choices": [
                                                {
                                                    "index": 0,
                                                    "delta": {
                                                        "content": data['new']
                                                    },
                                                    "finish_reason": None
                                                }
                                            ]
                                        }
                                        generated_text = generated_text + data['new']
                                        yield f'data: {json.dumps(response)}\n\n'
                                    elif data['completed']:
                                        yield 'data: [DONE]\n\n'
                                        end_time = time.time()
                                        elapsed_time = end_time - start_time
                                        log_to_db(
                                            handler.client_ip,
                                            handler.token,
                                            handler.prompt,
                                            generated_text,
                                            elapsed_time,
                                            handler.parameters,
                                            r_headers,
                                            200,
                                            r_url,
                                            handler.backend_url,
                                        )
                                        return
                    except GeneratorExit:
                        return
                    except Exception:
                        traceback.print_exc()
                        yield 'data: [DONE]\n\n'
                    finally:
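                        # Whether we finished, errored, or the client disconnected:
                        # tell the worker to stop and delete the stream.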
                        if event:
                            redis.publish(f'notifications:{event.event_id}', 'canceled')
                        if stream_name:
                            stream_redis.delete(stream_name)

                return Response(generate(), mimetype='text/event-stream')
            except Exception:
                traceback.print_exc()
                return 'Internal server error', 500