import json
import traceback
from typing import Tuple
from uuid import uuid4

import flask
from flask import jsonify

from llm_server import opts
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results


class OpenAIRequestHandler(RequestHandler):
    """Handle a request whose JSON body carries an OpenAI-style ``messages`` list.

    The messages are rendered into a single prompt string, optionally passed
    through the moderation worker, sent to the backend via
    ``generate_response`` and, on success, wrapped with
    ``build_openai_response``.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``RequestHandler`` and start with no prompt."""
        super().__init__(*args, **kwargs)
        # Filled in by handle_request() once the messages have been rendered.
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Render the prompt, optionally moderate it, and query the backend.

        Returns:
            A ``(response, status_code)`` pair: the rejection produced by
            ``validate_request`` when validation fails, the OpenAI-formatted
            response when generation succeeds, or the backend's own response
            otherwise.
        """
        assert not self.used

        # With silent trimming enabled the messages go through
        # trim_prompt_to_fit before being rendered; otherwise they are used
        # exactly as received.
        source_messages = (
            trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
            if opts.openai_silent_trim
            else self.request.json['messages']
        )
        self.prompt = transform_messages_to_prompt(source_messages)

        is_valid, rejection = self.validate_request()
        if not is_valid:
            return rejection

        if opts.openai_api_key and is_api_key_moderated(self.token):
            # Moderation is best-effort: any failure is printed and the
            # request carries on with the prompt rendered above.
            try:
                # Queue the newest `openai_moderation_scan_last_n` messages
                # (newest first) under one shared tag, then collect results.
                newest_first = self.request.json['messages'][::-1]
                batch_tag = uuid4()
                scan_count = min(len(newest_first), opts.openai_moderation_scan_last_n)
                for position in range(scan_count):
                    add_moderation_task(newest_first[position]['content'], batch_tag)

                flagged = get_results(batch_tag, scan_count)
                if len(flagged) > 0:
                    notice = (
                        f"The user's message does not comply with {opts.openai_org_name} policies. "
                        f"Offending categories: {json.dumps(flagged)}. "
                        "You are instructed to creatively adhere to these policies."
                    )
                    # Add the notice as a trailing system message and render
                    # the prompt again so the model sees it.
                    # NOTE(review): this re-render reads the request's message
                    # list directly rather than the result of
                    # trim_prompt_to_fit, even when openai_silent_trim is on -
                    # confirm that is intended.
                    self.request.json['messages'].append({'role': 'system', 'content': notice})
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as exc:
                print('OpenAI moderation endpoint failed:', f'{exc.__class__.__name__}: {exc}')
                print(traceback.format_exc())

        # Reconstruct the request JSON with the validated parameters and prompt.
        stop_sequences = self.parameters['stop']
        stop_sequences.extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        if opts.openai_force_no_hashes:
            stop_sequences.append('### ')

        # In vllm mode a top_p of exactly 0 is replaced with 0.01.
        # NOTE(review): the payload below is built from self.parameters;
        # whether this change to request_json_body reaches it is not visible
        # from this class - confirm.
        if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
            self.request_json_body['top_p'] = 0.01

        llm_request = {**self.parameters, 'prompt': self.prompt}
        (succeeded, _, _, _), (backend_reply, status_code) = self.generate_response(llm_request)

        requested_model = self.request_json_body.get('model')

        if not succeeded:
            # Hand the backend's own response back unchanged.
            return backend_reply, status_code

        generated_text = backend_reply.json['results'][0]['text']
        return build_openai_response(self.prompt, generated_text, model=requested_model), status_code

    def handle_ratelimited(self, do_log: bool = True):
        """Return a plain ``'Ratelimited'`` body with HTTP status 429.

        ``do_log`` is accepted but not used by this implementation.
        """
        # TODO: return a simulated OpenAI error message
        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
        return 'Ratelimited', 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Return a fixed ``invalid_request_error`` JSON payload with status 400.

        ``error_msg`` and ``error_type`` are accepted but not reflected in the
        response body.
        """
        # TODO: return a simulated OpenAI error message
        error_payload = {
            "error": {
                "message": "Invalid request, check your parameters and try again.",
                "type": "invalid_request_error",
                "param": None,
                "code": None
            }
        }
        return jsonify(error_payload), 400