import sqlite3
import time
from contextlib import closing
from typing import Tuple, Union

from flask import jsonify

from llm_server import opts
from llm_server.database import log_prompt
from llm_server.llm.hf_textgen.hf_textgen_backend import HfTextgenLLMBackend
from llm_server.llm.oobabooga.ooba_backend import OobaboogaLLMBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.helpers.http import validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import SemaphoreCheckerThread

# Priority used for requests without an API token, or whose token has no row in
# `token_auth`. A priority of exactly 0 bypasses the per-IP simultaneous-request
# limit in OobaRequestHandler.handle_request. NOTE(review): presumably lower
# values are dequeued sooner by priority_queue — confirm against its implementation.
DEFAULT_PRIORITY = 9999


def delete_dict_key(d: dict, k: Union[str, list]) -> dict:
    """Remove a key, or each key in a list, from ``d`` in place; missing keys are ignored.

    :param d: Dictionary to modify. It is mutated and also returned.
    :param k: A single key (``str``) or a ``list`` of keys to remove.
    :return: The same dictionary object ``d``.
    :raises ValueError: If ``k`` is neither a ``str`` nor a ``list``.
    """
    if isinstance(k, str):
        keys = [k]
    elif isinstance(k, list):
        keys = k
    else:
        raise ValueError(f'k must be a str or a list, got {type(k).__name__}')
    for key in keys:
        d.pop(key, None)
    return d


class OobaRequestHandler:
    """Handle one incoming generation request end to end.

    Validates the JSON body and its parameters, applies the per-IP
    simultaneous-request limit, enqueues the request on ``priority_queue``,
    waits for the result and hands it to the selected backend's
    ``handle_response``. Validation and rate-limit failures are logged via
    ``log_prompt`` and returned as SillyTavern-formatted error text.
    """

    def __init__(self, incoming_request):
        """
        :param incoming_request: The incoming Flask request. Its ``headers``,
            ``data`` and ``remote_addr`` attributes are read.
        """
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        # Must be set before get_priority(), which looks this token up.
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = self.get_backend()
        # Populated by get_parameters() once the body has been parsed.
        self.parameters = self.parameters_invalid_msg = None

    def validate_request(self) -> Tuple[bool, Union[str, None]]:
        """Check the request against server-wide limits that are independent of the backend.

        :return: ``(True, None)`` if the request is acceptable, otherwise
            ``(False, error_message)``.
        """
        # TODO: move this to LLMBackend
        # self.parameters may be None/empty when the backend rejected the request
        # parameters. This method runs before the caller checks for that, so treat
        # it as "nothing requested" rather than crashing with AttributeError. That
        # lets the caller report the backend's own error message.
        params = self.parameters or {}
        # Both spellings of the token limit are checked against the same maximum.
        # A key that is present but None counts as 0.
        for key in ('max_new_tokens', 'max_tokens'):
            if (params.get(key) or 0) > opts.max_new_tokens:
                return False, f'`max_new_tokens` must be less than or equal to {opts.max_new_tokens}'
        return True, None

    def get_client_ip(self):
        """Return the originating client IP.

        Checks Cloudflare's ``cf-connecting-ip`` header first, then the first
        entry of ``x-forwarded-for``, then falls back to the socket peer address.

        NOTE(review): both headers are client-controllable unless a trusted proxy
        overwrites them; this is only safe behind such a proxy.
        """
        headers = self.request.headers
        if headers.get('cf-connecting-ip'):
            return headers.get('cf-connecting-ip')
        if headers.get('x-forwarded-for'):
            # Left-most entry is the original client; strip spaces proxies put after commas.
            return headers.get('x-forwarded-for').split(',')[0].strip()
        return self.request.remote_addr

    def get_priority(self):
        """Look up the queue priority for ``self.token`` in the ``token_auth`` table.

        :return: The stored priority, or DEFAULT_PRIORITY when there is no token
            or the token has no matching row.
        """
        if not self.token:
            return DEFAULT_PRIORITY
        # closing() guarantees the connection is released even if the query raises.
        # sqlite3's own connection context manager only commits/rolls back; it does not close.
        with closing(sqlite3.connect(opts.database_path)) as conn:
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
            result = cursor.fetchone()
        if result:
            return result[0]
        return DEFAULT_PRIORITY

    def get_backend(self):
        """Instantiate the LLM backend selected by ``opts.mode``.

        :raises Exception: If ``opts.mode`` is not one of ``'oobabooga'``,
            ``'hf-textgen'`` or ``'vllm'``.
        """
        backends = {
            'oobabooga': OobaboogaLLMBackend,
            'hf-textgen': HfTextgenLLMBackend,
            'vllm': VLLMBackend,
        }
        backend_cls = backends.get(opts.mode)
        if backend_cls is None:
            raise Exception(f'Unknown backend mode: {opts.mode!r}')
        return backend_cls()

    def get_parameters(self):
        """Ask the backend to parse the request body.

        Stores the result in ``self.parameters`` and ``self.parameters_invalid_msg``.
        """
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def handle_request(self):
        """Validate, rate-limit, enqueue and wait for this request.

        :return: ``(response, 400)`` for unparseable JSON, and ``(response, 200)``
            carrying an error message for validation failures or rate limiting.
            Otherwise returns whatever ``self.backend.handle_response`` returns.
        """
        request_valid_json, self.request_json_body = validate_json(self.request.data)
        if not request_valid_json:
            return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

        self.get_parameters()
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

        request_valid, invalid_request_err_msg = self.validate_request()
        params_valid = bool(self.parameters)
        if not request_valid or not params_valid:
            # Collect a message for each failed check. Skip missing (None) messages
            # so the join below cannot raise TypeError.
            checks = ((request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg))
            error_messages = [msg for valid, msg in checks if not valid and msg is not None]
            combined_error_message = ', '.join(error_messages)
            err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
            # TODO: add a method to LLMBackend to return a formatted response string, since we have both Ooba and OpenAI response types
            # Sent with HTTP 200 and the error text in 'results'. NOTE(review): presumably
            # this is so the client displays the message as output; the 'code' field
            # carries the real status.
            return jsonify({
                'code': 400,
                'msg': 'parameter validation error',
                'results': [{'text': err}]
            }), 200

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        # This IP's requests that are currently queued plus those being processed.
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        if queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0:
            event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.priority)
        else:
            # Client was rate limited
            event = None

        if not event:
            return self.handle_ratelimited()

        # Block until the event is signalled. event.data is expected to hold
        # (success, response, error_msg) at that point.
        event.wait()
        success, response, error_msg = event.data
        elapsed_time = time.time() - self.start_time
        return self.backend.handle_response(success, response, error_msg, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

    def handle_ratelimited(self):
        """Log and return a SillyTavern-formatted error for a client over its simultaneous-request limit.

        :return: ``(response, 200)`` with the error text in ``results``.
        """
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, is_error=True)
        return jsonify({
            'results': [{'text': backend_response}]
        }), 200