import time
import traceback

import simplejson as json
import ujson
from flask import Response, jsonify, request
from redis import Redis

from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ...config.global_config import GlobalConfig
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
from ...logging import create_logger

# TODO: add rate-limit headers?

_logger = create_logger('OpenAICompletions')

# Server-sent-events frame that tells the client the stream is finished.
_SSE_DONE = 'data: [DONE]\n\n'


def _reported_model_name(request_json_body):
    """Return the model name to report back to the client.

    When ``openai_expose_our_model`` is enabled the name stored under the
    ``running_model`` Redis key is used (``'ERROR'`` if it is missing);
    otherwise the ``model`` value the client sent is echoed back.
    """
    if GlobalConfig.get().openai_expose_our_model:
        return redis.get('running_model', 'ERROR', dtype=str)
    return request_json_body.get('model')


def _complete_blocking(handler, request_json_body):
    """Run a non-streaming completion and build the OpenAI-style JSON reply."""
    invalid_oai_err_msg = validate_oai(request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg

    response, status_code = handler.handle_request(return_ok=False)
    if status_code == 429:
        return handler.handle_ratelimited()
    output = response.json['results'][0]['text']

    prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
    response_tokens = get_token_count(output, handler.backend_url)

    response = jsonify({
        "id": f"cmpl-{generate_oai_string(30)}",
        "object": "text_completion",
        "created": int(time.time()),
        "model": _reported_model_name(request_json_body),
        "choices": [
            {
                "text": output,
                "index": 0,
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens,
        },
    })

    # TODO:
    # stats = redis.get('proxy_stats', dtype=dict)
    # if stats:
    #     response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
    return response, 200


def _complete_streaming(handler, request_json_body):
    """Queue a streaming completion and relay its chunks as server-sent events."""
    if not GlobalConfig.get().enable_streaming:
        return 'Streaming disabled', 403

    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response
    handler.parameters, _ = handler.get_parameters()
    handler.request_json_body = {
        'prompt': handler.request_json_body['prompt'],
        'model': handler.request_json_body['model'],
        **handler.parameters,
    }

    invalid_oai_err_msg = validate_oai(handler.request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg

    if GlobalConfig.get().openai_silent_trim:
        # NOTE(review): this slices by characters while the limit is presumably
        # a token count -- confirm this coarse trim is intended.
        max_length = handler.cluster_backend_info['model_config']['max_position_embeddings']
        handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:max_length]
    if not handler.prompt:
        # Prevent issues on the backend.
        return 'Invalid prompt', 400

    start_time = time.time()

    # Re-validate now that the request body has been rebuilt and trimmed.
    request_valid, invalid_response = handler.validate_request()
    if not request_valid:
        return invalid_response

    event = None
    if not handler.is_client_ratelimited():
        event = priority_queue.put(
            handler.backend_url,
            (handler.request_json_body, handler.client_ip, handler.token, handler.parameters),
            handler.token_priority,
            handler.selected_model,
            do_stream=True,
        )
    if not event:
        log_to_db(
            handler.client_ip,
            handler.token,
            handler.prompt,
            None,
            None,
            handler.parameters,
            request.headers,
            429,
            request.url,
            handler.backend_url,
        )
        return handler.handle_ratelimited()

    try:
        # Copy what the generator needs from the request now: the generator
        # body runs while the response is being streamed, outside this call.
        r_headers = dict(request.headers)
        r_url = request.url
        model = _reported_model_name(request_json_body)
        oai_string = generate_oai_string(30)

        _, stream_name, error_msg = event.wait()
        if error_msg:
            _logger.error(f'OAI failed to start streaming: {error_msg}')
            return 'Request Timeout', 408

        def generate():
            stream_redis = Redis(db=8)
            generated_parts = []
            try:
                last_id = '0-0'
                while True:
                    stream_timeout = GlobalConfig.get().redis_stream_timeout
                    stream_data = stream_redis.xread({stream_name: last_id}, block=stream_timeout)
                    if not stream_data:
                        _logger.debug(f"No message received in {stream_timeout / 1000} seconds, closing stream.")
                        yield _SSE_DONE
                        # Actually close the stream; without this the loop kept
                        # polling and re-sending [DONE] after every timeout.
                        return

                    for stream_index, item in stream_data[0][1]:
                        last_id = stream_index
                        # NOTE(review): if stream IDs are auto-generated by Redis
                        # this is in milliseconds, whereas the non-streaming
                        # reply reports seconds -- confirm.
                        timestamp = int(stream_index.decode('utf-8').split('-')[0])
                        data = ujson.loads(item[b'data'])
                        if data['error']:
                            _logger.error(f'OAI streaming encountered error: {data["error"]}')
                            yield _SSE_DONE
                            return
                        elif data['new']:
                            # NOTE(review): chunks use a chat-style "delta" payload
                            # although the object is "text_completion" -- confirm
                            # clients rely on this shape before changing it.
                            response = {
                                "id": f"cmpl-{oai_string}",
                                "object": "text_completion",
                                "created": timestamp,
                                "model": model,
                                "choices": [
                                    {
                                        "index": 0,
                                        "delta": {
                                            "content": data['new']
                                        },
                                        "finish_reason": None,
                                    }
                                ],
                            }
                            generated_parts.append(data['new'])
                            yield f'data: {json.dumps(response)}\n\n'
                        elif data['completed']:
                            yield _SSE_DONE
                            elapsed_time = time.time() - start_time
                            log_to_db(
                                handler.client_ip,
                                handler.token,
                                handler.prompt,
                                ''.join(generated_parts),
                                elapsed_time,
                                handler.parameters,
                                r_headers,
                                200,
                                r_url,
                                handler.backend_url,
                            )
                            return
            except GeneratorExit:
                # This should be triggered if a client disconnects early.
                return
            except Exception:
                traceback.print_exc()
                yield _SSE_DONE
            finally:
                try:
                    if event:
                        redis.publish(f'notifications:{event.event_id}', 'canceled')
                    if stream_name:
                        stream_redis.delete(stream_name)
                finally:
                    # Release the per-stream connection even if cleanup failed.
                    stream_redis.close()

        return Response(generate(), mimetype='text/event-stream')
    except Exception:
        traceback.print_exc()
        return 'INTERNAL SERVER', 500


@openai_bp.route('/completions', methods=['POST'])
# The '<model_name>' converter is required for ``model_name`` to ever be bound
# from the URL; the rule previously read '//v1/completions'.
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
    """Handle an OpenAI-compatible ``/completions`` request.

    Validates the JSON body, converts it to the vLLM request format and then
    either returns a single JSON completion or, when the body sets ``stream``,
    a ``text/event-stream`` response fed from a Redis stream.

    Args:
        model_name: Model selected through the URL, or ``None`` when the
            generic ``/completions`` route was used.

    Returns:
        A Flask response value: a ``Response`` or a ``(body, status)`` tuple.

    Raises:
        NotImplementedError: If the selected backend's mode is not ``'vllm'``.
    """
    request_valid_json, request_json_body = validate_json(request)
    if not request_valid_json or not request_json_body.get('prompt'):
        return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

    handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
    if handler.offline:
        return return_invalid_model_err(model_name)

    backend_mode = handler.cluster_backend_info['mode']
    if backend_mode != 'vllm':
        # TODO: implement other backends
        raise NotImplementedError

    invalid_oai_err_msg = validate_oai(handler.request_json_body)
    if invalid_oai_err_msg:
        return invalid_oai_err_msg
    handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=backend_mode)

    if GlobalConfig.get().openai_silent_trim:
        handler.prompt = trim_string_to_fit(
            request_json_body['prompt'],
            handler.cluster_backend_info['model_config']['max_position_embeddings'],
            handler.backend_url,
        )
    # Otherwise the handle_request() call made for non-streaming requests will
    # load the prompt, so there is nothing else to do here.

    handler.request_json_body['prompt'] = handler.prompt

    if request_json_body.get('stream'):
        return _complete_streaming(handler, request_json_body)
    return _complete_blocking(handler, request_json_body)