import json
import time
import traceback

from flask import Response, jsonify, request

from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit


def _iter_nul_delimited_json(byte_chunks):
    """Yield each JSON document from a stream of NUL-terminated JSON documents.

    Bytes are buffered until the buffer ends with a NUL byte. Every complete
    document in the buffer is then decoded and yielded, and the buffer is
    cleared so that no document is parsed more than once.

    :param byte_chunks: iterable of ``bytes`` chunks, e.g. ``response.iter_content(chunk_size=1)``.
    :raises json.JSONDecodeError, UnicodeDecodeError: if a document is malformed.
    """
    buffer = b''
    for chunk in byte_chunks:
        buffer += chunk
        if not buffer.endswith(b'\x00'):
            continue
        for json_str in buffer.split(b'\x00'):
            if json_str:
                yield json.loads(json_str.decode())
        # Clear the buffer; otherwise every earlier document would be re-parsed
        # (and re-emitted) each time a new terminator arrives.
        buffer = b''


def _release_worker(handler):
    """Undo the counters the queue worker incremented when it picked up this request.

    Decrements the client's ``processing_ips`` entry and the active-worker count
    for the handler's model/backend. Must be called exactly once per request that
    made it past ``event.wait()``.
    """
    decrement_ip_count(handler.client_ip, 'processing_ips')
    decr_active_workers(handler.selected_model, handler.backend_url)


# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
    """OpenAI-compatible ``POST /chat/completions`` endpoint.

    Non-streaming requests are validated with ``validate_oai`` and handed to
    ``OpenAIRequestHandler.handle_request``.

    Streaming requests (``"stream"`` truthy) are converted with ``oai_to_vllm``,
    rendered into a prompt, queued on ``priority_queue`` until a worker picks
    them up, and then relayed to the client as server-sent events. Each event
    is a ``chat.completion.chunk`` object, and the stream ends with
    ``data: [DONE]``.

    Returns a 400 JSON error when the body is not valid JSON or lacks
    ``messages``/``model``, and 401 when streaming is disabled.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
        return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400

    handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body)

    if not request_json_body.get('stream'):
        try:
            invalid_oai_err_msg = validate_oai(request_json_body)
            if invalid_oai_err_msg:
                return invalid_oai_err_msg
            return handler.handle_request()
        except Exception:
            traceback.print_exc()
            return 'Internal server error', 500

    # ---- Streaming path ----
    if not opts.enable_streaming:
        return 'DISABLED', 401

    invalid_oai_err_msg = validate_oai(handler.request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg
    handler.request_json_body = oai_to_vllm(handler.request_json_body, hashes=False, mode=handler.cluster_backend_info['mode'])

    if opts.openai_silent_trim:
        trimmed_messages = trim_messages_to_fit(
            handler.request.json['messages'],
            handler.cluster_backend_info['model_config']['max_position_embeddings'],
            handler.backend_url,
        )
        handler.prompt = transform_messages_to_prompt(trimmed_messages)
    else:
        handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])

    response_status_code = 0
    start_time = time.time()

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response

    msg_to_backend = {
        **handler.parameters,
        'prompt': handler.prompt,
        'stream': True,
    }

    # Add a dummy event to the queue and wait for it to reach a worker
    event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
    if not event:
        log_prompt(
            handler.client_ip,
            handler.token,
            handler.prompt,
            None,
            None,
            handler.parameters,
            request.headers,
            response_status_code,
            request.url,
            handler.backend_url,
        )
        return handler.handle_ratelimited()

    # Wait for a worker to get our request and discard it.
    # From here on the worker has incremented our counters; they must be released
    # exactly once: by generate()'s finally, or by the except below if we fail
    # before handing the generator to Flask.
    _, _, _ = event.wait()

    try:
        response = generator(msg_to_backend, handler.backend_url)
        # Captured now: the Flask request context is gone while the body streams.
        r_headers = dict(request.headers)
        r_url = request.url
        model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
        oai_string = generate_oai_string(30)

        def generate():
            generated_text = ''
            try:
                for json_obj in _iter_nul_delimited_json(response.iter_content(chunk_size=1)):
                    try:
                        # Assumes each backend document carries the cumulative text with
                        # the prompt included, so the new delta is whatever follows
                        # prompt + already-relayed text. maxsplit=1 keeps the whole
                        # remainder even if that prefix recurs inside it.
                        new = json_obj['text'][0].split(handler.prompt + generated_text, 1)[1]
                    except IndexError:
                        # Empty 'text' list, or the expected prefix is absent.
                        continue
                    generated_text = generated_text + new

                    data = {
                        "id": f"chatcmpl-{oai_string}",
                        "object": "chat.completion.chunk",
                        "created": int(time.time()),
                        "model": model,
                        "choices": [
                            {
                                "index": 0,
                                "delta": {
                                    "content": new
                                },
                                "finish_reason": None
                            }
                        ]
                    }
                    yield f'data: {json.dumps(data)}\n\n'
                yield 'data: [DONE]\n\n'

                end_time = time.time()
                elapsed_time = end_time - start_time
                log_prompt(
                    handler.client_ip,
                    handler.token,
                    handler.prompt,
                    generated_text,
                    elapsed_time,
                    handler.parameters,
                    r_headers,
                    response_status_code,
                    r_url,
                    handler.backend_url,
                )
            finally:
                # The worker incremented it, we'll decrement it.
                _release_worker(handler)
                # Release the backend connection on completion or client disconnect.
                # Guarded because the backend call's return type isn't visible here.
                if hasattr(response, 'close'):
                    response.close()

        return Response(generate(), mimetype='text/event-stream')
    except Exception:
        traceback.print_exc()
        # generate() never ran, so its finally won't release the worker's counters.
        _release_worker(handler)
        return 'INTERNAL SERVER', 500