import json
import time
import traceback

from flask import Response, jsonify, request

from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit


# TODO: add rate-limit headers?

@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    """Serve ``POST /chat/completions`` in blocking or streaming mode.

    The JSON body must contain non-empty ``messages`` and ``model`` entries,
    otherwise a 400 JSON error is returned.  Requests without a truthy
    ``stream`` flag are delegated to ``OpenAIRequestHandler.handle_request``.
    Streaming requests are turned into a prompt, sent to the backend through
    ``generator`` and relayed to the client as ``text/event-stream`` chunks
    shaped like ``chat.completion.chunk`` objects, terminated by
    ``data: [DONE]``.  The prompt is logged with ``log_prompt`` once the
    stream has been fully consumed.

    Returns:
        A Flask response, or a ``(body, status)`` tuple for error paths.

    Raises:
        NotImplementedError: if the selected backend mode is not ``'vllm'``.
        Exception: any error raised while setting up the streaming response
            is logged and re-raised unchanged.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

    handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)

    if handler.cluster_backend_info['mode'] != 'vllm':
        # TODO: implement other backends
        raise NotImplementedError

    if not request_json_body.get('stream'):
        try:
            return handler.handle_request()
        except Exception:
            traceback.print_exc()
            return 'Internal server error', 500

    if not opts.enable_streaming:
        # TODO: return a proper OAI error message
        return 'disabled', 401

    max_position_embeddings = handler.cluster_backend_info['model_config']['max_position_embeddings']

    if opts.openai_silent_trim:
        handler.request_json_body['messages'] = trim_messages_to_fit(request_json_body['messages'], max_position_embeddings, handler.backend_url)

    # Logged as-is with the prompt; nothing in this function updates it.
    response_status_code = 0
    start_time = time.time()

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response

    if opts.openai_silent_trim:
        oai_messages = trim_messages_to_fit(handler.request.json['messages'], max_position_embeddings, handler.backend_url)
    else:
        oai_messages = handler.request.json['messages']

    handler.prompt = transform_messages_to_prompt(oai_messages)
    handler.parameters = oai_to_vllm(handler.parameters, hashes=True, mode=handler.cluster_backend_info['mode'])
    msg_to_backend = {
        **handler.parameters,
        'prompt': handler.prompt,
        'stream': True,
    }

    try:
        response = generator(msg_to_backend, handler.backend_url)
        # Captured eagerly so generate() does not touch the request object
        # while the response body is being streamed.
        r_headers = dict(request.headers)
        r_url = request.url
        model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
        oai_string = generate_oai_string(30)

        def generate():
            """Yield SSE lines for each backend message, then log the prompt."""
            generated_text = ''
            partial_response = b''
            for chunk in response.iter_content(chunk_size=1):
                partial_response += chunk
                # Backend messages are NUL-terminated JSON documents; wait
                # until the buffer ends on a message boundary.
                if not partial_response.endswith(b'\x00'):
                    continue

                json_strs = partial_response.split(b'\x00')
                # Everything buffered so far is complete, so drop it.
                # Without this every earlier message would be parsed again
                # each time a new one arrives.
                partial_response = b''

                for json_str in json_strs:
                    if not json_str:
                        continue
                    try:
                        json_obj = json.loads(json_str.decode())
                        # Each message's text is split on the prompt plus
                        # what has already been relayed; the remainder is
                        # the new text.  maxsplit=1 keeps all of it even if
                        # that prefix occurs again later in the text.
                        new = json_obj['text'][0].split(handler.prompt + generated_text, 1)[1]
                    except IndexError:
                        # No text entry, or the text does not contain the
                        # expected prefix: nothing new to relay.
                        continue
                    generated_text += new

                    data = {
                        "id": f"chatcmpl-{oai_string}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {
                                    "content": new
                                },
                                "finish_reason": None
                            }
                        ]
                    }
                    yield f'data: {json.dumps(data)}\n\n'

            yield 'data: [DONE]\n\n'
            end_time = time.time()
            elapsed_time = end_time - start_time

            log_prompt(
                handler.client_ip,
                handler.token,
                handler.prompt,
                generated_text,
                elapsed_time,
                handler.parameters,
                r_headers,
                response_status_code,
                r_url,
                handler.backend_url,
            )

        return Response(generate(), mimetype='text/event-stream')
    except Exception:
        # TODO: simulate OAI here
        # Log and propagate the real error instead of replacing it with a
        # bare, message-less Exception.
        traceback.print_exc()
        raise