"""Base request handler shared by the LLM proxy routes.

Defines ``RequestHandler``, which loads an incoming Flask request, validates
it against the configured backend, submits it to the priority queue and turns
the backend reply (or any failure) into a response. Route-specific subclasses
supply ``handle_request``, ``handle_ratelimited`` and ``handle_error``.
"""
import sqlite3
import time
from typing import Tuple, Union

import flask
from flask import Response

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread

# Priority used when the request carries no token, or the token has no row in
# the ``token_auth`` table.
DEFAULT_PRIORITY = 9999


class RequestHandler:
    """Validate, enqueue and answer a single incoming generation request.

    Subclasses must implement ``handle_request``, ``handle_ratelimited`` and
    ``handle_error``; this base class only raises ``NotImplementedError`` for
    them.
    """

    def __init__(self, incoming_request: flask.Request):
        """Load the request body and resolve client IP, token, priority and backend.

        Also records the current time for this client IP in
        ``SemaphoreCheckerThread.recent_prompters``.
        """
        self.request = incoming_request
        # Routes need to validate the body; here we just load it.
        _, self.request_json_body = validate_json(self.request)
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = get_backend()
        self.parameters = None
        self.used = False
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

    def get_client_ip(self):
        """Return the client address.

        Checks the ``cf-connecting-ip`` header first, then the first entry of
        ``x-forwarded-for``, and finally falls back to ``request.remote_addr``.
        """
        headers = self.request.headers
        connecting_ip = headers.get('cf-connecting-ip')
        if connecting_ip:
            return connecting_ip
        forwarded_for = headers.get('x-forwarded-for')
        if forwarded_for:
            return forwarded_for.split(',')[0]
        return self.request.remote_addr

    def get_priority(self):
        """Return the priority stored for ``self.token``, or ``DEFAULT_PRIORITY``.

        The default is returned when there is no token or when the token is
        not present in the ``token_auth`` table.
        """
        if self.token:
            conn = sqlite3.connect(opts.database_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
                result = cursor.fetchone()
            finally:
                # Close the connection even when the query raises.
                conn.close()
            if result:
                return result[0]
        return DEFAULT_PRIORITY

    def get_parameters(self):
        """Return ``(parameters, invalid_msg)`` as produced by the backend.

        A truthy ``max_tokens`` key in the request body is renamed to
        ``max_new_tokens`` (mutating ``self.request_json_body``) before the
        body is handed to ``self.backend.get_parameters``.
        """
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
        return parameters, parameters_invalid_msg

    def _log_error_and_respond(self, error_msg: str, prompt, elapsed_time, status_code) -> Tuple[flask.Response, int]:
        """Format ``error_msg``, log it as a failed prompt and return ``handle_error``'s result."""
        backend_response = format_sillytavern_err(error_msg, 'error')
        log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), status_code, self.request.url, is_error=True)
        return self.handle_error(backend_response)

    def validate_request(self) -> Tuple[bool, Tuple[Response | None, int]]:
        """Parse and validate the request parameters.

        Stores the parsed parameters on ``self.parameters``.

        Returns:
            ``(True, (None, 0))`` when the request is valid, otherwise
            ``(False, <result of handle_error>)`` after logging the failure.
        """
        self.parameters, parameters_invalid_msg = self.get_parameters()
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)

        if request_valid:
            return True, (None, 0)

        # Collect whichever of the two validation stages failed with a message.
        checks = [
            (request_valid, invalid_request_err_msg),
            (not bool(parameters_invalid_msg), parameters_invalid_msg),
        ]
        error_messages = [msg for valid, msg in checks if not valid and msg]
        combined_error_message = ', '.join(error_messages)
        return False, self._log_error_and_respond(
            f'Validation Error: {combined_error_message}.',
            self.request_json_body.get('prompt', ''),
            0,
            0,
        )

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        """Queue ``llm_request``, wait for the backend and build the reply.

        Args:
            llm_request: Request payload; must contain a ``'prompt'`` key.

        Returns:
            A pair ``((success, response, error_msg, elapsed_time), reply)``.
            On any failure (rate limit, invalid prompt, backend error,
            unusable JSON) the first element is ``(False, None, None, 0)`` and
            ``reply`` comes from ``handle_ratelimited`` or ``handle_error``.
            On success ``reply`` comes from ``self.backend.handle_response``
            and ``self.used`` is set to ``True``.
        """
        failure = (False, None, None, 0)
        prompt = llm_request['prompt']

        if self.is_client_ratelimited():
            return failure, self.handle_ratelimited()

        # Validate the prompt right before submission since the backend handler may have changed something.
        prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
        if not prompt_valid:
            return failure, self._log_error_and_respond(
                f'Validation Error: {invalid_prompt_err_msg}.',
                self.request_json_body.get('prompt', ''),
                0,
                0,
            )

        event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        if not event:
            # The queue returned a falsy event; answer as rate limited.
            return failure, self.handle_ratelimited()

        success, response, error_msg = event.wait()
        elapsed_time = time.time() - self.start_time

        if response:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response or error_msg:
            if not error_msg:
                error_msg = 'Unknown error.'
            else:
                error_msg = error_msg.strip('.') + '.'
            return failure, self._log_error_and_respond(error_msg, prompt, None, response_status_code)

        # ===============================================

        response_valid_json, response_json_body = validate_json(response)

        # The backend didn't send valid JSON. Only inspect the body when it
        # was parsed into a dict; otherwise ``.get`` below would raise instead
        # of reporting the error to the client.
        return_json_err = not response_valid_json or not isinstance(response_json_body, dict)

        if not return_json_err:
            # Make sure the backend didn't crap out.
            results = response_json_body.get('results', [])
            if len(results) and not results[0].get('text'):
                return_json_err = True

        if return_json_err:
            error_msg = 'The backend did not return valid JSON.'
            return failure, self._log_error_and_respond(error_msg, prompt, elapsed_time, response_status_code)

        # ===============================================

        self.used = True
        return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def is_client_ratelimited(self) -> bool:
        """Return ``True`` when this client already has too many requests in flight.

        The count is the sum of this IP's entries in the ``queued_ip_count``
        and ``processing_ips`` Redis dicts, compared against
        ``opts.simultaneous_requests_per_ip``. Requests with priority ``0``
        are never rate limited.
        """
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        return not (queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0)

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Entry point for a route; must be implemented by subclasses."""
        # Must include this in your child.
        # if self.used:
        #     raise Exception('Can only use a RequestHandler object once.')
        raise NotImplementedError

    def handle_ratelimited(self) -> Tuple[flask.Response, int]:
        """Build the reply for a rate-limited client; must be implemented by subclasses."""
        raise NotImplementedError

    def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
        """Build the reply for an error message; must be implemented by subclasses."""
        raise NotImplementedError


def get_backend():
    """Return a new backend instance selected by ``opts.mode``.

    Raises:
        Exception: If ``opts.mode`` is neither ``'oobabooga'`` nor ``'vllm'``.
    """
    if opts.mode == 'oobabooga':
        return OobaboogaBackend()
    elif opts.mode == 'vllm':
        return VLLMBackend()
    else:
        raise Exception(f'Unknown backend mode: {opts.mode!r}')


def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove key ``k`` (or every key in the list ``k``) from ``d`` in place.

    Missing keys are ignored.

    Returns:
        The same dict ``d``, after removal.

    Raises:
        ValueError: If ``k`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(k, str):
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or a list, got {type(k).__name__}')
    return d