import time
from typing import Tuple, Union

import flask
from flask import Response, request

from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit, log_prompt
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue


class RequestHandler:
    """Base class for handling a single incoming LLM request.

    On construction it loads the JSON body, resolves the client IP and auth
    token, looks up the token's rate limits, selects a cluster backend and
    instantiates the matching backend client. It provides shared parameter
    validation (validate_request) and the queue/wait/log flow
    (generate_response). Subclasses must implement handle_request(),
    handle_ratelimited() and handle_error().
    """

    def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
        """
        :param incoming_request: the Flask request being handled.
        :param selected_model: passed to get_a_cluster_backend() to choose the backend URL.
        :param incoming_json: optional JSON body to use instead of reading it from incoming_request.
        :raises Exception: if the body is not valid JSON (routes are expected to reject that earlier).
        """
        self.request = incoming_request
        # Read from the request object we were handed, not the Flask context-local
        # `request`, so every field of this handler comes from the same request.
        self.enable_backend_blind_rrd = self.request.headers.get('LLM-Blind-RRD', False) == 'true'

        # Routes need to validate it, here we just load it
        if incoming_json:
            self.request_valid_json, self.request_json_body = validate_json(incoming_json)
        else:
            self.request_valid_json, self.request_json_body = validate_json(self.request)
        if not self.request_valid_json:
            raise Exception('Not valid JSON. Routes are supposed to reject invalid JSON.')

        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.get_auth_token()
        self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
        self.backend_url = get_a_cluster_backend(selected_model)
        self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
        self.selected_model = self.cluster_backend_info['model']

        if not self.cluster_backend_info.get('mode'):
            # Debug aid: dump context before get_backend_handler()/['mode'] fails.
            print(selected_model, self.backend_url, self.cluster_backend_info)
        self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)

        # Filled in by validate_request().
        self.parameters = None
        # Set once generate_response() hands a response to the backend.
        self.used = False

        # get_auth_token() returns None when no key was supplied, so guard before
        # calling startswith(). Internal SYSTEM__ tokens are not counted.
        if not (self.token and self.token.startswith('SYSTEM__')):
            # "recent_prompters" is only used for stats.
            redis.zadd('recent_prompters', {self.client_ip: time.time()})

    def get_auth_token(self):
        """Return the client's API token, or None if none was supplied.

        Sources are checked in order:
        1. `X-API-KEY` in the JSON body.
        2. The `X-Api-Key` header.
        3. The `Authorization` header, parsed with parse_token().
        """
        if self.request_json_body.get('X-API-KEY'):
            return self.request_json_body['X-API-KEY']
        elif self.request.headers.get('X-Api-Key'):
            return self.request.headers['X-Api-Key']
        elif self.request.headers.get('Authorization'):
            return parse_token(self.request.headers['Authorization'])
        return None

    def get_client_ip(self):
        """Return the originating client IP address.

        Proxy-supplied headers are consulted in priority order. For
        X-Forwarded-For the left-most entry is used. If none is present,
        the socket peer address is returned.

        NOTE(review): these headers are client-controllable unless a trusted
        proxy overwrites them; the IP feeds rate limiting, so confirm the
        deployment strips them at the edge.
        """
        for header in ('Llm-Connecting-Ip', 'X-Connecting-IP', 'Cf-Connecting-Ip'):
            value = self.request.headers.get(header)
            if value:
                return value
        forwarded_for = self.request.headers.get('X-Forwarded-For')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return self.request.remote_addr

    def get_parameters(self):
        """Normalize the request body and let the backend parse generation parameters.

        An OpenAI-style `max_tokens` key is renamed to `max_new_tokens`.
        This mutates self.request_json_body in place before parsing.

        :return: (parameters, parameters_invalid_msg) from self.backend.get_parameters().
        """
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        return self.backend.get_parameters(self.request_json_body)

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
        """
        This needs to be called at the start of the subclass handle_request() method.

        Parses the parameters (stored on self.parameters). It then applies the
        generic max_new_tokens limit plus the backend's prompt and request checks.

        :param prompt: prompt to validate with the backend; the prompt check is skipped when falsy.
        :param do_log: when True, a rejected request is recorded with log_prompt().
        :return: (True, (None, 0)) if valid; otherwise (False, <handle_error() response tuple>).
        """
        invalid_request_err_msgs = []
        self.parameters, parameters_invalid_msg = self.get_parameters()  # Parameters will be None if invalid.
        if self.parameters and not parameters_invalid_msg:
            # Backends shouldn't check max_new_tokens, but rather things specific to their backend.
            # Let the RequestHandler do the generic checks.
            # `or 0` keeps an explicit null from raising TypeError on the comparison.
            if (self.parameters.get('max_new_tokens') or 0) > opts.max_new_tokens:
                invalid_request_err_msgs.append(f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}')

            if prompt:
                prompt_valid, invalid_prompt_err_msg = self.backend.validate_prompt(prompt)
                if not prompt_valid:
                    invalid_request_err_msgs.append(invalid_prompt_err_msg)

            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters, prompt, self.request)
            if not request_valid:
                invalid_request_err_msgs.append(invalid_request_err_msg)
        else:
            # The backend may return no parameters and no message; avoid `None + '.'` below.
            invalid_request_err_msgs.append(parameters_invalid_msg or 'Invalid parameters')

        if not invalid_request_err_msgs:
            return True, (None, 0)

        if len(invalid_request_err_msgs) > 1:
            # Format multiple error messages each on a new line.
            combined_error_message = '\n'.join(f'\n{x}.' for x in invalid_request_err_msgs)
        else:
            # Otherwise, just grab the first and only one.
            combined_error_message = invalid_request_err_msgs[0] + '.'
        backend_response = self.handle_error(combined_error_message, 'Validation Error')

        if do_log:
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
        return False, backend_response

    def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
        """Validate, enqueue and wait for a generation, then hand the result to the backend.

        :param llm_request: payload put on the priority queue; must contain 'prompt'.
        :return: ((success, response, error_msg, elapsed_time), (flask_response, status_code)).
                 On any failure the first element is (False, None, None, 0) and the second
                 is the error/rate-limit response built by the subclass.
        """
        prompt = llm_request['prompt']

        if not self.is_client_ratelimited():
            # Validate again before submission since the backend handler may have changed something.
            # Also, this is the first time we validate the prompt.
            request_valid, invalid_response = self.validate_request(prompt, do_log=True)
            if not request_valid:
                return (False, None, None, 0), invalid_response
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.backend_url), self.token_priority, self.selected_model)
        else:
            event = None

        if not event:
            return (False, None, None, 0), self.handle_ratelimited()

        # TODO: add wait timeout
        success, response, error_msg = event.wait()

        end_time = time.time()
        elapsed_time = end_time - self.start_time

        # Use an identity check rather than truthiness. If `response` is a
        # requests.Response (presumably; confirm against the queue worker),
        # it is falsy for 4xx/5xx. A truthiness check would drop the real
        # status code from the error log.
        if response is not None:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response or error_msg:
            error_msg = error_msg.strip('.') + '.' if error_msg else 'Unknown error.'
            backend_response = self.handle_error(error_msg)
            log_prompt(ip=self.client_ip, token=self.token, prompt=prompt, response=backend_response[0].data.decode('utf-8'), gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=response_status_code, request_url=self.request.url, backend_url=self.backend_url, is_error=True)
            return (False, None, None, 0), backend_response

        # ===============================================

        response_valid_json, response_json_body = validate_json(response)
        # The backend didn't send valid JSON
        return_json_err = not response_valid_json

        # Make sure the backend didn't crap out. Only inspect the body when it
        # parsed; otherwise it may not be a dict at all.
        if response_valid_json:
            results = response_json_body.get('results', [])
            if results and not results[0].get('text'):
                return_json_err = True

        if return_json_err:
            error_msg = 'The backend did not return valid JSON.'
            backend_response = self.handle_error(error_msg)
            log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
            return (False, None, None, 0), backend_response

        # ===============================================

        self.used = True
        return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def is_client_ratelimited(self) -> bool:
        """Return True if this client IP already has too many requests queued or processing.

        Tokens with priority 0 are never rate limited. Otherwise the sum of the
        IP's queued count (priority_queue) and processing count (redis hash
        'processing_ips') is compared against self.token_simultaneous_ip.
        """
        if self.token_priority == 0:
            return False

        queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
        processing_raw = redis.hget('processing_ips', self.client_ip)
        processing_ip = int(processing_raw) if processing_raw else 0

        if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
            print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
            return True
        return False

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Entry point for a route; must be implemented by subclasses."""
        # Must include this in your child.
        # if self.used:
        #     raise Exception('Can only use a RequestHandler object once.')
        raise NotImplementedError

    def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
        """Build the response returned when the client is rate limited; subclass hook."""
        raise NotImplementedError

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Build an error response for `error_msg`; subclass hook.

        Callers in this class read `[0].data` from the result, so it must be a
        (flask.Response, status_code) tuple.
        """
        raise NotImplementedError


def get_backend_handler(mode, backend_url: str):
    """Instantiate the backend client for a cluster backend `mode`.

    :param mode: 'oobabooga' or 'vllm'.
    :param backend_url: URL passed to the backend client's constructor.
    :raises ValueError: for an unrecognized mode (a subclass of the Exception previously raised).
    """
    if mode == 'oobabooga':
        return OobaboogaBackend(backend_url)
    elif mode == 'vllm':
        return VLLMBackend(backend_url)
    raise ValueError(f'Unknown backend mode: {mode!r}')


def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove key `k` (a str) or each key in `k` (a list) from `d` if present.

    `d` is modified in place and also returned for convenience.

    :raises ValueError: if `k` is neither a str nor a list (nothing is removed).
    """
    if isinstance(k, str):
        keys = [k]
    elif isinstance(k, list):
        keys = k
    else:
        raise ValueError
    for key in keys:
        d.pop(key, None)
    return d


def before_request():
    """Flask before-request hook.

    Calls auto_set_base_client_api() for every request. For every endpoint
    except `v1.get_stats` it calls require_api_key(). A non-None result from
    require_api_key() is returned, and Flask then uses it as the response
    instead of calling the view.
    """
    auto_set_base_client_api(request)
    if request.endpoint != 'v1.get_stats':
        response = require_api_key()
        if response is not None:
            return response