import json
import threading
import time
import traceback

from flask import Response, jsonify, request

from . import openai_bp
from ..cache import redis
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
from ...llm.vllm import tokenize


def _format_stream_chunk(oai_string, model, content):
    """Return one server-sent event carrying ``content`` as a chat-completion delta.

    The event is a ``data: <json>`` line followed by a blank line; the JSON
    payload is a ``chat.completion.chunk`` object with a single choice.
    """
    data = {
        "id": f"chatcmpl-{oai_string}",
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {
                    "content": content
                },
                "finish_reason": None
            }
        ]
    }
    return f'data: {json.dumps(data)}\n\n'


# TODO: add rate-limit headers?

@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    """Handle ``POST /chat/completions``.

    The JSON body must contain truthy ``messages`` and ``model`` entries,
    otherwise a 400 is returned. When ``stream`` is truthy the backend output
    is relayed as a ``text/event-stream`` response; otherwise the request is
    delegated to ``OpenAIRequestHandler.handle_request()``.

    Raises:
        NotImplementedError: streaming was requested and ``opts.mode`` is not
            ``'vllm'``.
        Exception: streaming was requested and ``handler.validate_request()``
            reported the request as invalid.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

    handler = OpenAIRequestHandler(request, request_json_body)

    if not request_json_body.get('stream'):
        try:
            return handler.handle_request()
        except Exception:
            traceback.print_exc()
            return 'Internal server error', 500

    if not opts.enable_streaming:
        # TODO: return a proper OAI error message
        return 'disabled', 401

    if opts.mode != 'vllm':
        # TODO: implement other backends
        raise NotImplementedError

    # NOTE(review): this value is passed to log_prompt() unchanged, so streamed
    # requests are logged with status 0 -- confirm that is intended.
    response_status_code = 0
    start_time = time.time()

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        # TODO: simulate OAI here
        # NOTE(review): invalid_response is discarded; presumably it should be
        # returned to the client once its shape is confirmed.
        raise Exception('TODO: simulate OAI here')

    handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
    msg_to_backend = {
        **handler.parameters,
        'prompt': handler.prompt,
        'stream': True,
    }

    try:
        response = generator(msg_to_backend)
        # Captured now because generate() and the logging thread run later.
        r_headers = dict(request.headers)
        r_url = request.url
        if opts.openai_expose_our_model:
            model = redis.get('running_model', str, 'ERROR')
        else:
            model = request_json_body.get('model')
        oai_string = generate_oai_string(30)

        def generate():
            """Yield SSE chunks for newly generated text, then log the exchange."""
            generated_text = ''
            partial_response = b''
            # The backend stream is consumed as NUL-terminated JSON objects.
            for chunk in response.iter_content(chunk_size=1):
                partial_response += chunk
                if not partial_response.endswith(b'\x00'):
                    continue

                json_strs = partial_response.split(b'\x00')
                # Clear the buffer so objects handled here are not parsed
                # again when the next terminator arrives.
                partial_response = b''

                for json_str in json_strs:
                    if not json_str:
                        continue
                    json_obj = json.loads(json_str.decode())
                    try:
                        text = json_obj['text'][0]
                    except IndexError:
                        # Empty 'text' list: nothing to forward.
                        continue

                    # Only the part following the prompt plus what has already
                    # been sent is new; skip objects that do not contain it.
                    _, matched, new = text.partition(handler.prompt + generated_text)
                    if not matched:
                        continue
                    generated_text += new
                    yield _format_stream_chunk(oai_string, model, new)

            yield 'data: [DONE]\n\n'
            elapsed_time = time.time() - start_time

            def background_task():
                generated_tokens = tokenize(generated_text)
                log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

            # TODO: use async/await instead of threads
            # Deliberately not joined: waiting here would keep this generator
            # (and so the response) open until tokenizing and logging finish.
            threading.Thread(target=background_task).start()

        return Response(generate(), mimetype='text/event-stream')
    except Exception:
        # TODO: simulate OAI here
        # Same handling as the non-streaming path: log the failure and
        # answer with a plain 500.
        traceback.print_exc()
        return 'Internal server error', 500