local-llm-server/llm_server/routes/request_handler.py

98 lines
3.2 KiB
Python

import sqlite3
import time
from typing import Union
import flask
from llm_server import opts
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.cache import redis
from llm_server.routes.stats import SemaphoreCheckerThread
# Fallback priority for requests with no API token or an unrecognized one.
# Priority 0 is privileged and bypasses per-IP rate limiting (see
# RequestHandler.is_client_ratelimited).
DEFAULT_PRIORITY = 9999
class RequestHandler:
    """Base handler for a single incoming inference request.

    Resolves the client IP, API token, priority, and backend once at
    construction time and registers the client with the stats thread.
    Concrete subclasses implement handle_request() and handle_ratelimited().
    """

    def __init__(self, incoming_request: 'flask.Request'):
        self.request_json_body = None  # populated later by the concrete handler
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = get_backend()
        self.parameters = self.parameters_invalid_msg = None
        self.used = False
        # Record this client as a recent prompter for the stats thread.
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

    def get_client_ip(self):
        """Return the originating client IP, honoring proxy headers.

        Precedence: Cloudflare's cf-connecting-ip, then the first entry of
        x-forwarded-for, then the direct socket address.
        """
        cf_ip = self.request.headers.get('cf-connecting-ip')
        if cf_ip:
            return cf_ip
        forwarded = self.request.headers.get('x-forwarded-for')
        if forwarded:
            # x-forwarded-for may be a comma-separated proxy chain; the first
            # element is the client. strip() guards against ", "-joined values.
            return forwarded.split(',')[0].strip()
        return self.request.remote_addr

    def get_priority(self):
        """Look up this request's token in the token_auth table.

        Returns:
            The stored priority for the token, or DEFAULT_PRIORITY when no
            token was supplied or the token is unknown.
        """
        if not self.token:
            return DEFAULT_PRIORITY
        conn = sqlite3.connect(opts.database_path)
        try:
            cursor = conn.cursor()
            # Parameterized query — never interpolate the token directly.
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
            result = cursor.fetchone()
        finally:
            # Original code leaked the connection if execute() raised.
            conn.close()
        if result:
            return result[0]
        return DEFAULT_PRIORITY

    def load_parameters(self):
        """Normalize the JSON body and extract backend sampling parameters."""
        # OpenAI-style clients send max_tokens; internally we use max_new_tokens.
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def validate_request(self):
        """Validate the parameters, then the full request, against the backend.

        Returns:
            ((params_valid, params_err_msg), (request_valid, request_err_msg)).
            Request validation only runs when parameter extraction succeeded.
        """
        self.load_parameters()
        params_valid = False
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            params_valid = True
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
        return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg)

    def is_client_ratelimited(self):
        """Return True when this IP has too many queued + in-flight requests.

        Priority-0 (privileged) tokens are never rate limited.
        """
        queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
        return not (queued_ip_count < opts.simultaneous_requests_per_ip or self.priority == 0)

    def handle_request(self):
        """Process the request; must be implemented by subclasses."""
        raise NotImplementedError

    def handle_ratelimited(self):
        """Respond to a rate-limited client; must be implemented by subclasses."""
        raise NotImplementedError
def get_backend():
    """Instantiate the LLM backend selected by opts.mode.

    Returns:
        An OobaboogaBackend or VLLMBackend instance.

    Raises:
        Exception: if opts.mode is not a recognized backend name.
    """
    if opts.mode == 'oobabooga':
        return OobaboogaBackend()
    elif opts.mode == 'vllm':
        return VLLMBackend()
    else:
        # Keep the exception type for caller compatibility, but include the
        # offending mode so misconfiguration is diagnosable from a traceback.
        raise Exception(f'Unknown backend mode: {opts.mode!r}')
def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove key(s) *k* from *d* in place and return *d*.

    Args:
        d: Dictionary to mutate.
        k: A single key, or a list of keys. Missing keys are ignored.

    Returns:
        The same (mutated) dictionary, for call-chaining convenience.

    Raises:
        ValueError: if *k* is neither a str nor a list.
    """
    if isinstance(k, str):
        # pop with a default removes the key without a pre-membership test.
        d.pop(k, None)
    elif isinstance(k, list):
        for item in k:
            d.pop(item, None)
    else:
        raise ValueError(f'k must be a str or list, not {type(k).__name__}')
    return d