import json
import time
import traceback

import ujson
from flask import Response, jsonify, request
from redis import Redis

from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...logging import create_logger

_logger = create_logger('OpenAIChatCompletions')


# TODO: add rate-limit headers?

@openai_bp.route('/chat/completions', methods=['POST'])
@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
def openai_chat_completions(model_name=None):
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
    else:
        handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
        if handler.offline:
            return return_oai_internal_server_error(f'backend {handler.backend_url} is offline')

        if not request_json_body.get('stream'):
            # Non-streaming requests are handled entirely by the request handler.
            try:
                return handler.handle_request()
            except Exception:
                traceback.print_exc()
                return 'Internal server error', 500
        else:
            if not GlobalConfig.get().enable_streaming:
                return 'Streaming disabled', 403

            invalid_oai_err_msg = validate_oai(handler.request_json_body)
            if invalid_oai_err_msg:
                return invalid_oai_err_msg

            handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
            handler.parameters, e = handler.get_parameters()
            handler.request_json_body = {
                'messages': handler.request_json_body['messages'],
                'model': handler.request_json_body['model'],
                **handler.parameters
            }

            if GlobalConfig.get().openai_silent_trim:
                handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
            else:
                handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
            if not handler.prompt:
                # Prevent issues on the backend.
                return 'Invalid prompt', 400

            # Need to set the prompt in the JSON body since that's what the inference worker expects.
            handler.request_json_body['prompt'] = handler.prompt

            start_time = time.time()

            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                return invalid_response

            # Queue the request for the backend unless the client is rate-limited.
            event = None
            if not handler.is_client_ratelimited():
                event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
            if not event:
                log_to_db(
                    handler.client_ip,
                    handler.token,
                    handler.prompt,
                    None,
                    None,
                    handler.parameters,
                    request.headers,
                    429,
                    request.url,
                    handler.backend_url,
                )
                return handler.handle_ratelimited()

            try:
                r_headers = dict(request.headers)
                r_url = request.url
                model = redis.get('running_model', 'ERROR', dtype=str) if GlobalConfig.get().openai_expose_our_model else request_json_body.get('model')
                oai_string = generate_oai_string(30)

                # Need to do this before we enter generate() since we want to be able to
                # return a 408 if necessary.
                _, stream_name, error_msg = event.wait()
                if error_msg:
                    _logger.error(f'OAI failed to start streaming: {error_msg}')
                    stream_name = None  # Set to None so the finally block ignores it.
                    return 'Request Timeout', 408

                def generate():
                    # Relay tokens from the Redis stream written by the inference worker
                    # as OpenAI-style server-sent events.
                    stream_redis = Redis(db=8)
                    generated_text = ''
                    try:
                        last_id = '0-0'
                        while True:
                            stream_data = stream_redis.xread({stream_name: last_id}, block=GlobalConfig.get().redis_stream_timeout)
                            if not stream_data:
                                _logger.debug(f"No message received in {GlobalConfig.get().redis_stream_timeout / 1000} seconds, closing stream.")
                                yield 'data: [DONE]\n\n'
                                return
                            else:
                                for stream_index, item in stream_data[0][1]:
                                    last_id = stream_index
                                    timestamp = int(stream_index.decode('utf-8').split('-')[0])
                                    data = ujson.loads(item[b'data'])
                                    if data['error']:
                                        # Not printing the error since we can just check the daemon log.
                                        _logger.warning(f'OAI streaming encountered error: {data["error"]}')
                                        yield 'data: [DONE]\n\n'
                                        return
                                    elif data['new']:
                                        response = {
                                            "id": f"chatcmpl-{oai_string}",
                                            "object": "chat.completion.chunk",
                                            "created": timestamp,
                                            "model": model,
                                            "choices": [
                                                {
                                                    "index": 0,
                                                    "delta": {
                                                        "content": data['new']
                                                    },
                                                    "finish_reason": None
                                                }
                                            ]
                                        }
                                        generated_text = generated_text + data['new']
                                        yield f'data: {json.dumps(response)}\n\n'
                                    elif data['completed']:
                                        yield 'data: [DONE]\n\n'
                                        end_time = time.time()
                                        elapsed_time = end_time - start_time
                                        log_to_db(
                                            handler.client_ip,
                                            handler.token,
                                            handler.prompt,
                                            generated_text,
                                            elapsed_time,
                                            handler.parameters,
                                            r_headers,
                                            200,
                                            r_url,
                                            handler.backend_url,
                                        )
                                        return
                    except GeneratorExit:
                        return
                    except Exception:
                        traceback.print_exc()
                        yield 'data: [DONE]\n\n'
                    finally:
                        if event:
                            redis.publish(f'notifications:{event.event_id}', 'canceled')
                        if stream_name:
                            stream_redis.delete(stream_name)

                return Response(generate(), mimetype='text/event-stream')
            except Exception:
                traceback.print_exc()
                return 'Internal server error', 500