import sqlite3
import time
from typing import Union

import flask

from llm_server import opts
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.stats import SemaphoreCheckerThread

# Priority assigned when the request carries no token, or the token has no
# row in the `token_auth` table.
DEFAULT_PRIORITY = 9999


class RequestHandler:
    """Base class holding the per-request state shared by concrete handlers.

    Subclasses must implement `handle_request` and `handle_ratelimited`, and
    are expected to populate `request_json_body` before calling
    `load_parameters` / `validate_request` (it is initialised to ``None``).
    """

    def __init__(self, incoming_request: flask.Request):
        """Capture client identity, priority and backend for this request.

        Also records the current time for the client's IP in
        `SemaphoreCheckerThread.recent_prompters`.
        """
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        # get_priority() reads self.token, so it must run after the line above.
        self.priority = self.get_priority()
        self.backend = get_backend()
        self.parameters = self.parameters_invalid_msg = None
        self.used = False
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

    def get_client_ip(self):
        """Return the client's IP address.

        Checks the `cf-connecting-ip` header first, then the first entry of
        `x-forwarded-for`, and finally falls back to `request.remote_addr`.
        """
        headers = self.request.headers
        cf_ip = headers.get('cf-connecting-ip')
        if cf_ip:
            return cf_ip
        forwarded_for = headers.get('x-forwarded-for')
        if forwarded_for:
            # The header is a comma-separated list; the first entry is used.
            # Strip so stray whitespace does not produce a distinct key.
            return forwarded_for.split(',')[0].strip()
        return self.request.remote_addr

    def get_priority(self):
        """Look up the priority stored for this request's API token.

        Returns:
            The `priority` column of the matching `token_auth` row, or
            `DEFAULT_PRIORITY` when there is no token or no matching row.
        """
        if self.token:
            conn = sqlite3.connect(opts.database_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
                result = cursor.fetchone()
            finally:
                # Always release the connection, even if the query raises.
                conn.close()
            if result:
                return result[0]
        return DEFAULT_PRIORITY

    def load_parameters(self):
        """Populate `parameters` and `parameters_invalid_msg` from the body."""
        # Handle OpenAI: rename `max_tokens` to `max_new_tokens` before the
        # body is handed to the backend.
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def validate_request(self):
        """Load the parameters and ask the backend to validate them.

        Returns:
            A pair of tuples:
            ``(params_valid, parameters_invalid_msg)`` and
            ``(request_valid, invalid_request_err_msg)``.
            The backend's `validate_request` is only consulted when
            `parameters` is truthy; otherwise both flags are ``False``.
        """
        self.load_parameters()
        params_valid = False
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            params_valid = True
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
        return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg)

    def is_client_ratelimited(self):
        """Return ``True`` when this client has hit the per-IP request limit.

        The count is the sum of this IP's entries in the `queued_ip_count`
        and `processing_ips` dicts. Requests with priority 0 are never
        rate limited.
        """
        queued_ip_count = (
            redis.get_dict('queued_ip_count').get(self.client_ip, 0)
            + redis.get_dict('processing_ips').get(self.client_ip, 0)
        )
        under_limit = queued_ip_count < opts.simultaneous_requests_per_ip
        return not (under_limit or self.priority == 0)

    def handle_request(self):
        """Process the request; must be implemented by subclasses."""
        raise NotImplementedError

    def handle_ratelimited(self):
        """Respond to a rate-limited client; must be implemented by subclasses."""
        raise NotImplementedError


def get_backend():
    """Instantiate the backend selected by `opts.mode`.

    Returns:
        An `OobaboogaBackend` for mode ``'oobabooga'`` or a `VLLMBackend`
        for mode ``'vllm'``.

    Raises:
        Exception: If `opts.mode` is neither of the supported values.
    """
    if opts.mode == 'oobabooga':
        return OobaboogaBackend()
    elif opts.mode == 'vllm':
        return VLLMBackend()
    else:
        raise Exception(f'Unsupported backend mode: {opts.mode!r}')


def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove one key or a list of keys from `d` in place, ignoring absentees.

    Args:
        d: Dictionary to modify.
        k: A single string key, or a list of keys.

    Returns:
        The same dictionary object `d`, after removal.

    Raises:
        ValueError: If `k` is neither a `str` nor a `list`.
    """
    if isinstance(k, str):
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or a list, got {type(k).__name__}')
    return d