import json
import threading
import time
import traceback

from flask import Response, jsonify, request

from . import openai_bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler, build_openai_response, generate_oai_string
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize


# TODO: add rate-limit headers?


@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
    else:
        handler = OpenAIRequestHandler(request, request_json_body)
        if request_json_body.get('stream'):
            if not opts.enable_streaming:
                # TODO: return a proper OAI error message
                return 'disabled', 401

            if opts.mode != 'vllm':
                # TODO: implement other backends
                raise NotImplementedError

            response_status_code = 0
            start_time = time.time()
            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                # TODO: simulate OAI here
                raise Exception
            else:
                handler.prompt = handler.transform_messages_to_prompt()
                msg_to_backend = {
                    **handler.parameters,
                    'prompt': handler.prompt,
                    'stream': True,
                }
                try:
                    response = generator(msg_to_backend)
                    r_headers = dict(request.headers)
                    r_url = request.url
                    model = opts.running_model if opts.openai_expose_our_model else request_json_body.get('model')

                    def generate():
                        # Re-emit the backend's null-byte-delimited JSON stream as
                        # OpenAI-style "chat.completion.chunk" SSE events.
                        generated_text = ''
                        partial_response = b''
                        for chunk in response.iter_content(chunk_size=1):
                            partial_response += chunk
                            if partial_response.endswith(b'\x00'):
                                json_strs = partial_response.split(b'\x00')
                                partial_response = b''  # don't re-parse chunks we've already handled
                                for json_str in json_strs:
                                    if json_str:
                                        try:
                                            json_obj = json.loads(json_str.decode())
                                            # The backend sends the full text so far; keep only the new part.
                                            new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                            generated_text = generated_text + new
                                        except IndexError:
                                            # This chunk didn't extend past what we've already sent, so skip it.
                                            continue

                                        data = {
                                            "id": f"chatcmpl-{generate_oai_string(30)}",
                                            "object": "chat.completion.chunk",
                                            "created": int(time.time()),
                                            "model": model,
                                            "choices": [
                                                {
                                                    "index": 0,
                                                    "delta": {
                                                        "content": new
                                                    },
                                                    "finish_reason": None
                                                }
                                            ]
                                        }
                                        yield f'data: {json.dumps(data)}\n\n'
                        yield 'data: [DONE]\n\n'
                        end_time = time.time()
                        elapsed_time = end_time - start_time

                        def background_task():
                            # Tokenize and log the finished response without blocking the stream.
                            generated_tokens = tokenize(generated_text)
                            log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

                        # TODO: use async/await instead of threads
                        threading.Thread(target=background_task).start()

                    return Response(generate(), mimetype='text/event-stream')
                except Exception:
                    # TODO: simulate OAI here
                    raise Exception
        else:
            try:
                return handler.handle_request()
            except Exception as e:
                print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
                traceback.print_exc()
                print(request.data)
                return build_openai_response('', format_sillytavern_err('Server encountered exception.', 'error')), 500
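
# Example (illustrative only): consuming this endpoint's SSE stream from a client.
# The host, port, URL prefix, and model name below are assumptions -- they depend on
# how `openai_bp` is registered on the Flask app and how the server is deployed.
#
#   import json
#   import requests
#
#   with requests.post('http://localhost:5000/v1/chat/completions',
#                      json={'model': 'some-model',
#                            'messages': [{'role': 'user', 'content': 'Hello'}],
#                            'stream': True},
#                      stream=True) as r:
#       for line in r.iter_lines():
#           if not line or not line.startswith(b'data: '):
#               continue
#           payload = line[len(b'data: '):]
#           if payload == b'[DONE]':
#               break
#           event = json.loads(payload)
#           print(event['choices'][0]['delta'].get('content', ''), end='', flush=True)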