import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4

import flask
from flask import Response, jsonify, make_response, request

from llm_server.cluster.backend import get_model_choices
from llm_server.config.global_config import GlobalConfig
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated, is_valid_api_key
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_oai_internal_server_error, return_oai_invalid_request_error
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.logging import create_logger
from llm_server.routes.auth import parse_token
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results

_logger = create_logger('OpenAIRequestHandler')


class OpenAIRequestHandler(RequestHandler):
    """Handle OpenAI-style chat completion requests.

    The incoming ``messages`` list is flattened into a single prompt, the
    request parameters are converted with ``oai_to_vllm``, and the backend's
    text output is wrapped in a ``chat.completion`` shaped JSON response.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Prompt built from the request's messages; populated by handle_request().
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Validate the request, optionally moderate it, and query the backend.

        Returns:
            A ``(response, status_code)`` pair. On success the response is the
            OpenAI-style body built by ``build_openai_response``; otherwise it
            is the error response produced by validation or by the backend.
        """
        assert not self.used
        if self.offline:
            return return_oai_internal_server_error(f'backend {self.backend_url} is offline.')

        # OpenAI-specific handling may only be switched off by a valid API key
        # carrying the SYSTEM__ prefix. The token is parsed only when the
        # opt-out header is actually present.
        disable_openai_handling = False
        if request.headers.get('Llm-Disable-Openai', False) == 'true':
            auth_token = parse_token(request.headers.get('Authorization', ''))
            disable_openai_handling = is_valid_api_key(auth_token) and auth_token.startswith('SYSTEM__')

        if GlobalConfig.get().openai_silent_trim:
            oai_messages = trim_messages_to_fit(
                self.request.json['messages'],
                self.cluster_backend_info['model_config']['max_position_embeddings'],
                self.backend_url,
            )
        else:
            oai_messages = self.request.json['messages']

        self.prompt = transform_messages_to_prompt(oai_messages, disable_openai_handling)
        self.request_json_body = oai_to_vllm(
            self.request_json_body,
            stop_hashes=('instruct' not in self.request_json_body['model'].lower()),
            mode=self.cluster_backend_info['mode'],
        )

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        if not self.prompt:
            # The OAI error helpers are returned as complete results everywhere
            # else in this class, so the result is returned as-is here too
            # instead of being wrapped in a second (value, status) tuple.
            # NOTE(review): presumably the helper yields (response, 400) - confirm.
            return return_oai_invalid_request_error('Invalid prompt')

        # TODO: support Ooba backend
        self.parameters = oai_to_vllm(
            self.parameters,
            stop_hashes=('instruct' not in self.request_json_body['model'].lower()),
            mode=self.cluster_backend_info['mode'],
        )

        invalid_oai_err_msg = validate_oai(self.request_json_body)
        if invalid_oai_err_msg:
            return invalid_oai_err_msg

        if not disable_openai_handling and (GlobalConfig.get().openai_moderation_enabled and GlobalConfig.get().openai_api_key and is_api_key_moderated(self.token)):
            try:
                # Submit the most recent messages (newest first, any role) for
                # moderation, up to the configured scan depth.
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()
                tag = uuid4()
                num_to_check = min(len(msg_l), GlobalConfig.get().openai_moderation_scan_last_n)
                for i in range(num_to_check):
                    add_moderation_task(msg_l[i]['content'], tag)

                flagged_categories = get_results(tag, num_to_check)

                if flagged_categories:
                    mod_msg = f"The user's message does not comply with {GlobalConfig.get().openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                    self.request.json['messages'].append({'role': 'system', 'content': mod_msg})
                    # NOTE(review): the prompt is rebuilt from the request's own
                    # message list, not from the silently trimmed one above.
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
                # Moderation is best-effort: a failure is logged and the request proceeds.
                _logger.error(f'OpenAI moderation endpoint failed: {e.__class__.__name__}: {e}')
                traceback.print_exc()

        llm_request = {**self.parameters, 'prompt': self.prompt}
        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        model = self.request_json_body.get('model')

        if success:
            return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
        return backend_response, backend_response_status_code

    def handle_ratelimited(self, do_log: bool = True):
        """Build an OpenAI-style HTTP 429 rate limit response.

        Args:
            do_log: When true, the rejected request is recorded via ``log_to_db``.

        Returns:
            A ``(response, 429)`` pair. The ``x-ratelimit-reset-requests``
            header carries the default model's estimated wait in whole
            seconds, or 2 seconds when that estimate is not positive.
        """
        model_choices, default_model = get_model_choices()
        default_model_info = model_choices[default_model]
        estimated_wait = default_model_info['estimated_wait']
        wait_seconds = int(estimated_wait) if estimated_wait > 0 else 2

        # NOTE(review): the message text is fixed and does not reflect wait_seconds.
        response = jsonify({
            "error": {
                "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
                "type": "rate_limit_exceeded",
                "param": None,
                "code": None
            }
        })
        response.headers['x-ratelimit-limit-requests'] = '2'
        response.headers['x-ratelimit-remaining-requests'] = '0'
        response.headers['x-ratelimit-reset-requests'] = f"{wait_seconds}s"

        if do_log:
            log_to_db(
                self.client_ip,
                self.token,
                self.request_json_body.get('prompt', ''),
                response.data.decode('utf-8'),
                None,
                self.parameters,
                dict(self.request.headers),
                429,
                self.request.url,
                self.backend_url,
                is_error=True,
            )

        return response, 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Log ``error_msg`` and return the generic OpenAI invalid-request error.

        The message is only logged; it is not passed on to the error response.
        ``error_type`` is accepted but not used by this handler.
        """
        _logger.error(f'OAI Error: {error_msg}')
        return return_oai_invalid_request_error()

    def build_openai_response(self, prompt, response, model=None):
        """Wrap generated text in a ``chat.completion`` response.

        Args:
            prompt: The full prompt that was sent to the backend.
            response: The text generated by the backend.
            model: Model name to report when the running model is not exposed.

        Returns:
            A Flask response with status 200 containing the completion and
            token usage counts.
        """
        # Separate the user's prompt from the context: keep only the text after
        # the last '### USER:' marker, minus surrounding spaces and one trailing newline.
        prompt_parts = prompt.split('### USER:')
        if len(prompt_parts) > 1:
            prompt = re.sub(r'\n$', '', prompt_parts[-1].strip(' '))

        # Make sure the bot doesn't put any other instructions in its response
        response = re.sub(ANTI_RESPONSE_RE, '', response)
        response = re.sub(ANTI_CONTINUATION_RE, '', response)

        prompt_tokens = get_token_count(prompt, self.backend_url)
        response_tokens = get_token_count(response, self.backend_url)
        running_model = redis.get('running_model', 'ERROR', dtype=str)

        response = make_response(jsonify({
            "id": f"chatcmpl-{generate_oai_string(30)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": running_model if GlobalConfig.get().openai_expose_our_model else model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response,
                },
                "logprobs": None,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": response_tokens,
                "total_tokens": prompt_tokens + response_tokens
            }
        }), 200)
        return response

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        """Load and validate the request parameters.

        Args:
            prompt: Forwarded unchanged to the base class validation.
            do_log: Forwarded unchanged to the base class validation.

        Returns:
            ``(is_valid, error)`` where ``error`` is the response to send back
            when ``is_valid`` is false.
        """
        self.parameters, parameters_invalid_msg = self.get_parameters()
        if not self.parameters:
            _logger.error(f'OAI BACKEND VALIDATION ERROR: {parameters_invalid_msg}')
            return False, (Response('Invalid request, check your parameters and try again.'), 400)

        invalid_oai_err_msg = validate_oai(self.parameters)
        if invalid_oai_err_msg:
            return False, invalid_oai_err_msg

        # The OpenAI-specific checks passed; the base handler runs the remaining validation.
        return super().validate_request(prompt, do_log)