import time
import traceback

import simplejson as json
from flask import Response, jsonify, request

from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...database.database import do_db_log
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit


# TODO: add rate-limit headers?


@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
    else:
        handler = OobaRequestHandler(incoming_request=request)
        if handler.cluster_backend_info['mode'] != 'vllm':
            # TODO: implement other backends
            raise NotImplementedError

        invalid_oai_err_msg = validate_oai(handler.request_json_body)
        if invalid_oai_err_msg:
            return invalid_oai_err_msg

        handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])

        if opts.openai_silent_trim:
            handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
        else:
            # The handle_request() call below will load the prompt so we don't have
            # to do anything else here.
            pass

        if not request_json_body.get('stream'):
            invalid_oai_err_msg = validate_oai(request_json_body)
            if invalid_oai_err_msg:
                return invalid_oai_err_msg

            response, status_code = handler.handle_request(return_ok=False)
            if status_code == 429:
                return handler.handle_ratelimited()
            output = response.json['results'][0]['text']

            # TODO: async/await
            prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
            response_tokens = get_token_count(output, handler.backend_url)
            running_model = redis.get('running_model', 'ERROR', dtype=str)

            # Wrap the backend output in an OpenAI-style text_completion response.
            response = jsonify({
                "id": f"cmpl-{generate_oai_string(30)}",
                "object": "text_completion",
                "created": int(time.time()),
                "model": running_model if opts.openai_expose_our_model else request_json_body.get('model'),
                "choices": [
                    {
                        "text": output,
                        "index": 0,
                        "logprobs": None,
                        "finish_reason": "stop"
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": response_tokens,
                    "total_tokens": prompt_tokens + response_tokens
                }
            })

            stats = redis.get('proxy_stats', dtype=dict)
            if stats:
                response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
            return response, 200
        else:
            if not opts.enable_streaming:
                return 'DISABLED', 401

            response_status_code = 0
            start_time = time.time()

            request_valid, invalid_response = handler.validate_request()
            if not request_valid:
                return invalid_response
            else:
                handler.prompt = handler.request_json_body['prompt']
                msg_to_backend = {
                    **handler.parameters,
                    'prompt': handler.prompt,
                    'stream': True,
                }

                # Add a dummy event to the queue and wait for it to reach a worker
                event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
                if not event:
                    log_to_db(
                        handler.client_ip,
                        handler.token,
                        handler.prompt,
                        None,
                        None,
                        handler.parameters,
                        request.headers,
                        response_status_code,
                        request.url,
                        handler.backend_url,
                    )
                    return handler.handle_ratelimited()

                # Wait for a worker to get our request and discard it.
                _, _, _ = event.wait()

                try:
                    response = generator(msg_to_backend, handler.backend_url)
                    r_headers = dict(request.headers)
                    r_url = request.url
                    model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
                    oai_string = generate_oai_string(30)

                    def generate():
                        try:
                            generated_text = ''
                            partial_response = b''
                            # The backend streams null-terminated JSON messages; buffer bytes
                            # until a complete message (or batch of messages) has arrived.
                            for chunk in response.iter_content(chunk_size=1):
                                partial_response += chunk
                                if partial_response.endswith(b'\x00'):
                                    json_strs = partial_response.split(b'\x00')
                                    for json_str in json_strs:
                                        if json_str:
                                            try:
                                                json_obj = json.loads(json_str.decode())
                                                new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                                                generated_text = generated_text + new
                                            except IndexError:
                                                # The cumulative text doesn't contain prompt + generated_text
                                                # (stale or duplicate message); skip it.
                                                continue

                                            data = {
                                                "id": f"cmpl-{oai_string}",
                                                "object": "text_completion",
                                                "created": int(time.time()),
                                                "model": model,
                                                "choices": [
                                                    {
                                                        "index": 0,
                                                        "delta": {
                                                            "content": new
                                                        },
                                                        "finish_reason": None
                                                    }
                                                ]
                                            }
                                            yield f'data: {json.dumps(data)}\n\n'
                            yield 'data: [DONE]\n\n'

                            end_time = time.time()
                            elapsed_time = end_time - start_time

                            log_to_db(
                                handler.client_ip,
                                handler.token,
                                handler.prompt,
                                generated_text,
                                elapsed_time,
                                handler.parameters,
                                r_headers,
                                response_status_code,
                                r_url,
                                handler.backend_url,
                            )
                        finally:
                            # The worker incremented it, we'll decrement it.
                            decrement_ip_count(handler.client_ip, 'processing_ips')
                            decr_active_workers(handler.selected_model, handler.backend_url)

                    return Response(generate(), mimetype='text/event-stream')
                except Exception:
                    traceback.print_exc()
                    return 'INTERNAL SERVER', 500
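
# Illustrative usage sketch (not part of this module): how a client might call the
# endpoint above. The base URL, path prefix, and Bearer auth header are assumptions;
# the real prefix depends on how `openai_bp` is registered by the application.
#
#     import requests
#
#     resp = requests.post(
#         'http://localhost:5000/api/openai/v1/completions',  # hypothetical URL/prefix
#         headers={'Authorization': 'Bearer YOUR-TOKEN'},      # hypothetical auth header
#         json={'prompt': 'Once upon a time', 'max_tokens': 32, 'stream': False},
#     )
#     print(resp.json()['choices'][0]['text'])
#
# With `"stream": true` the endpoint instead returns a `text/event-stream` of
# `data: {...}` chunks followed by a final `data: [DONE]` line, as produced by
# generate() above.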