import time
from typing import Tuple, Union

import flask
from flask import Response, request

from llm_server import opts
from llm_server.database.conn import database
from llm_server.database.database import log_prompt
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue

# Priority used when the request has no token, or the token is not found in
# the token_auth table. It is passed to priority_queue.put(); note that a
# token_priority of exactly 0 bypasses the per-IP limit in is_client_ratelimited().
DEFAULT_PRIORITY = 9999


class RequestHandler:
    """
    Base class for serving a single LLM generation request.

    Subclasses implement handle_request(), handle_ratelimited() and
    handle_error(). This class provides the shared plumbing: loading the JSON
    body, identifying the client IP and API token, looking up the token's
    rate-limit settings, validating parameters/prompt, queueing the request
    and logging the outcome.
    """

    def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str, None] = None):
        """
        :param incoming_request: the Flask request being served.
        :param incoming_json: optional JSON body (dict or string). When falsy,
            the body is read from ``incoming_request`` via validate_json().
        :raises Exception: if the body is not valid JSON. Routes are expected
            to reject invalid JSON before constructing a handler.
        """
        self.request = incoming_request
        # Read from the request object we were handed rather than the Flask
        # context-local global, consistent with every other header lookup here.
        self.enable_backend_blind_rrd = incoming_request.headers.get('LLM-Blind-RRD', False) == 'true'

        # Routes need to validate it, here we just load it
        if incoming_json:
            self.request_valid_json, self.request_json_body = validate_json(incoming_json)
        else:
            self.request_valid_json, self.request_json_body = validate_json(self.request)
        if not self.request_valid_json:
            raise Exception('Not valid JSON. Routes are supposed to reject invalid JSON.')

        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.get_auth_token()
        self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
        self.backend = get_backend()
        self.parameters = None
        self.used = False
        # Record this client in the 'recent_prompters' sorted set, scored by current time.
        redis.zadd('recent_prompters', {self.client_ip: time.time()})

    def get_auth_token(self):
        """
        Return the API token supplied with this request, or None.

        Checked in order: the ``X-API-KEY`` field of the JSON body, the
        ``X-Api-Key`` header, then the ``Authorization`` header (passed
        through parse_token()).
        """
        if self.request_json_body.get('X-API-KEY'):
            return self.request_json_body['X-API-KEY']
        if self.request.headers.get('X-Api-Key'):
            return self.request.headers['X-Api-Key']
        if self.request.headers.get('Authorization'):
            return parse_token(self.request.headers['Authorization'])
        return None

    def get_client_ip(self):
        """
        Return the client's IP address.

        Checked in order: ``X-Connecting-IP``, ``Cf-Connecting-Ip``, the first
        entry of ``X-Forwarded-For``, then the socket's remote address.
        """
        # NOTE(review): these headers are taken at face value. Unless a trusted
        # proxy always overwrites them, a client can spoof its IP and sidestep
        # the per-IP limits in is_client_ratelimited().
        headers = self.request.headers
        if headers.get('X-Connecting-IP'):
            return headers.get('X-Connecting-IP')
        if headers.get('Cf-Connecting-Ip'):
            return headers.get('Cf-Connecting-Ip')
        if headers.get('X-Forwarded-For'):
            # Comma-separated list; take the first entry without surrounding whitespace.
            return headers.get('X-Forwarded-For').split(',')[0].strip()
        return self.request.remote_addr

    def get_token_ratelimit(self):
        """
        Look up the queue priority and simultaneous-request limit for this token.

        :return: ``(priority, simultaneous_ip)``. Defaults are DEFAULT_PRIORITY
            and ``opts.simultaneous_requests_per_ip`` when there is no token or
            no matching token_auth row. A NULL simultaneous_ip in the database
            means "no limit" and is mapped to a very large number.
        """
        priority = DEFAULT_PRIORITY
        simultaneous_ip = opts.simultaneous_requests_per_ip
        if self.token:
            cursor = database.cursor()
            try:
                cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
                result = cursor.fetchone()
                if result:
                    priority, simultaneous_ip = result
                    if simultaneous_ip is None:
                        # No ratelimit for this token if null
                        simultaneous_ip = 999999999
            finally:
                cursor.close()
        return priority, simultaneous_ip

    def get_parameters(self):
        """
        Extract backend generation parameters from the JSON body.

        A truthy ``max_tokens`` field is renamed in place to ``max_new_tokens``
        before the body is handed to the backend.

        :return: ``(parameters, parameters_invalid_msg)`` from the backend.
        """
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
        return parameters, parameters_invalid_msg

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        """
        Validate the request's parameters (and optionally its prompt).

        This needs to be called at the start of the subclass handle_request()
        method. Side effect: sets ``self.parameters``.

        :param prompt: prompt to check with the backend; skipped when falsy.
        :param do_log: if True, record a failed validation with log_prompt().
        :return: ``(True, (None, 0))`` when valid, otherwise
            ``(False, <result of handle_error()>)``.
        """
        invalid_request_err_msgs = []
        self.parameters, parameters_invalid_msg = self.get_parameters()  # Parameters will be None if invalid.
        if self.parameters and not parameters_invalid_msg:
            # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
            # Let the RequestHandler do the generic checks.
            # An explicit None is treated as "not provided" rather than crashing the comparison.
            if (self.parameters.get('max_new_tokens') or 0) > opts.max_new_tokens:
                invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')

            if prompt:
                prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
                if not prompt_valid:
                    invalid_request_err_msgs.append(invalid_prompt_err_msg or 'Invalid prompt')

            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters, prompt, self.request)
            if not request_valid:
                invalid_request_err_msgs.append(invalid_request_err_msg or 'Invalid request')
        else:
            # Guard against a missing message so the '+ "."' below cannot hit None.
            invalid_request_err_msgs.append(parameters_invalid_msg or 'Invalid parameters')

        if not invalid_request_err_msgs:
            return True, (None, 0)

        if len(invalid_request_err_msgs) > 1:
            # Format multiple error messages each on a new line.
            combined_error_message = '\n'.join(f'\n{x}.' for x in invalid_request_err_msgs)
        else:
            # Otherwise, just grab the first and only one.
            combined_error_message = invalid_request_err_msgs[0] + '.'
        backend_response = self.handle_error(combined_error_message, 'Validation Error')

        if do_log:
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
        return False, backend_response

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        """
        Validate, enqueue and wait for ``llm_request``, then build the client response.

        :param llm_request: request payload; must contain a ``'prompt'`` key.
        :return: ``((success, response, error_msg, elapsed_time), client_response)``.
            On any failure (rate limit, validation, backend error, bad JSON)
            the first tuple is ``(False, None, None, 0)`` and ``client_response``
            comes from handle_ratelimited() or handle_error().
        """
        prompt = llm_request['prompt']

        if self.is_client_ratelimited():
            return (False, None, None, 0), self.handle_ratelimited()

        # Validate again before submission since the backend handler may have changed something.
        # Also, this is the first time we validate the prompt.
        request_valid, invalid_response = self.validate_request(prompt, do_log=True)
        if not request_valid:
            return (False, None, None, 0), invalid_response

        event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
        if not event:
            # put() gave us nothing to wait on; report it as rate-limited.
            return (False, None, None, 0), self.handle_ratelimited()

        success, response, error_msg = event.wait()
        elapsed_time = time.time() - self.start_time

        if response:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response or error_msg:
            if error_msg:
                error_msg = error_msg.strip('.') + '.'
            else:
                error_msg = 'Unknown error.'
            backend_response = self.handle_error(error_msg)
            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return (False, None, None, 0), backend_response

        # ===============================================

        response_valid_json, response_json_body = validate_json(response)
        # The backend didn't send valid JSON. Only inspect the body once we
        # know it parsed; otherwise .get() may be called on a non-dict.
        return_json_err = not response_valid_json
        if not return_json_err:
            # Make sure the backend didn't crap out.
            results = response_json_body.get('results', [])
            if results and not results[0].get('text'):
                return_json_err = True

        if return_json_err:
            error_msg = 'The backend did not return valid JSON.'
            backend_response = self.handle_error(error_msg)
            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
            return (False, None, None, 0), backend_response

        # ===============================================

        self.used = True
        return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def is_client_ratelimited(self) -> bool:
        """
        Return True if this client already has too many requests in flight.

        In-flight = requests queued for this IP plus the count stored for it
        in the 'processing_ips' Redis hash. A token_priority of 0 is never
        rate-limited.
        """
        queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
        processing_raw = redis.hget('processing_ips', self.client_ip)
        processing_ip = int(processing_raw) if processing_raw else 0

        in_flight = queued_ip_count + processing_ip
        if in_flight < self.token_simultaneous_ip or self.token_priority == 0:
            return False
        print(f'Rejecting request from {self.client_ip} - {in_flight} queued + processing.')
        return True

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Serve the request. Must be implemented by subclasses."""
        # Must include this in your child.
        # if self.used:
        #     raise Exception('Can only use a RequestHandler object once.')
        raise NotImplementedError

    def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
        """Build the response for a rate-limited client. Must be implemented by subclasses."""
        raise NotImplementedError

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Build an error response for ``error_msg``. Must be implemented by subclasses."""
        raise NotImplementedError


def get_backend():
    """
    Instantiate the backend selected by ``opts.mode``.

    :raises Exception: if ``opts.mode`` is not 'oobabooga' or 'vllm'.
    """
    if opts.mode == 'oobabooga':
        return OobaboogaBackend()
    if opts.mode == 'vllm':
        return VLLMBackend()
    raise Exception(f'Unknown backend mode: {opts.mode!r}')


def delete_dict_key(d: dict, k: Union[str, list]):
    """
    Remove key ``k`` (a str) or every key in ``k`` (a list) from ``d`` in place.

    Missing keys are ignored.

    :return: the same dict ``d``.
    :raises ValueError: if ``k`` is neither a str nor a list.
    """
    if isinstance(k, str):
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or list, got {type(k).__name__}')
    return d


def before_request():
    """
    Per-request setup: set the base client API from the request, then require
    an API key on every endpoint except 'v1.get_stats'.

    :return: the response from require_api_key() when it is not None,
        otherwise None.
    """
    auto_set_base_client_api(request)
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:
            return response