2023-08-30 18:53:26 -06:00
|
|
|
import sqlite3
|
|
|
|
import time
|
|
|
|
from typing import Union
|
|
|
|
|
2023-09-13 11:22:33 -06:00
|
|
|
import flask
|
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
from llm_server import opts
|
2023-09-12 16:40:09 -06:00
|
|
|
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
2023-09-11 20:47:19 -06:00
|
|
|
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
2023-08-30 18:53:26 -06:00
|
|
|
from llm_server.routes.cache import redis
|
|
|
|
from llm_server.routes.stats import SemaphoreCheckerThread
|
|
|
|
|
|
|
|
# Fallback priority used when a request carries no API token, or when its
# token has no row in the token_auth table.
DEFAULT_PRIORITY = 9_999
|
|
|
|
|
|
|
|
|
2023-09-12 16:40:09 -06:00
|
|
|
class RequestHandler:
    """Base class for handling one incoming inference request.

    On construction it resolves the client IP, the API-key priority and the
    configured backend. It also provides parameter validation and a per-IP
    rate-limit check. Subclasses must implement ``handle_request`` and
    ``handle_ratelimited``.
    """

    def __init__(self, incoming_request: flask.Request):
        # NOTE(review): not set here; presumably assigned by the caller/subclass
        # before load_parameters()/validate_request() are used — confirm.
        self.request_json_body = None
        self.request = incoming_request
        self.start_time = time.time()
        self.client_ip = self.get_client_ip()
        # Must be set before get_priority(), which reads self.token.
        self.token = self.request.headers.get('X-Api-Key')
        self.priority = self.get_priority()
        self.backend = get_backend()
        # Filled in by load_parameters().
        self.parameters = self.parameters_invalid_msg = None
        self.used = False
        # Record the time this IP last sent a prompt.
        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

    def get_client_ip(self):
        """Return the client's IP address.

        Prefers the ``cf-connecting-ip`` header, then the left-most entry of
        ``x-forwarded-for``, and finally the socket peer address.

        NOTE(review): both headers are client-supplied unless a trusted proxy
        overwrites them, so a direct client could spoof its IP and sidestep
        per-IP rate limiting. Confirm the deployment always sits behind such
        a proxy.
        """
        headers = self.request.headers
        if headers.get('cf-connecting-ip'):
            return headers.get('cf-connecting-ip')
        elif headers.get('x-forwarded-for'):
            # The header is a comma-separated list whose entries may be padded
            # with spaces; the first entry is the originating client.
            return headers.get('x-forwarded-for').split(',')[0].strip()
        else:
            return self.request.remote_addr

    def get_priority(self):
        """Look up the priority associated with this request's API token.

        Returns:
            The ``priority`` column from ``token_auth`` for ``self.token``.
            Returns ``DEFAULT_PRIORITY`` when there is no token, the token is
            unknown, or its stored priority is NULL.
        """
        if not self.token:
            return DEFAULT_PRIORITY
        conn = sqlite3.connect(opts.database_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (self.token,))
            result = cursor.fetchone()
        finally:
            # Always release the connection, even if the query raises.
            conn.close()
        # `result` is a 1-tuple, so `(None,)` would be truthy; check the value
        # itself so a NULL priority doesn't leak out as None.
        if result is not None and result[0] is not None:
            return result[0]
        return DEFAULT_PRIORITY

    def load_parameters(self):
        """Parse ``self.request_json_body`` into backend parameters.

        Renames an OpenAI-style ``max_tokens`` key to ``max_new_tokens``
        (only when its value is truthy). Then stores the backend's parsed
        parameters in ``self.parameters`` and any error message in
        ``self.parameters_invalid_msg``.
        """
        # Handle OpenAI
        if self.request_json_body.get('max_tokens'):
            self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
        self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

    def validate_request(self):
        """Load the parameters and validate them with the backend.

        Returns:
            ``((params_valid, parameters_invalid_msg),
            (request_valid, invalid_request_err_msg))``. The backend's request
            validation only runs when parameters were parsed successfully;
            otherwise ``request_valid`` is False and its message is None.
        """
        self.load_parameters()
        params_valid = False
        request_valid = False
        invalid_request_err_msg = None
        if self.parameters:
            params_valid = True
            request_valid, invalid_request_err_msg = self.backend.validate_request(self.parameters)
        return (params_valid, self.parameters_invalid_msg), (request_valid, invalid_request_err_msg)

    def is_client_ratelimited(self):
        """Return True if this client IP has too many queued/processing requests.

        The count is this IP's entry in redis ``queued_ip_count`` plus its
        entry in ``processing_ips``. Requests whose priority is 0 are never
        rate-limited.
        """
        queued_ip_count = (redis.get_dict('queued_ip_count').get(self.client_ip, 0)
                           + redis.get_dict('processing_ips').get(self.client_ip, 0))
        return queued_ip_count >= opts.simultaneous_requests_per_ip and self.priority != 0

    def handle_request(self):
        """Process the request; must be implemented by subclasses."""
        raise NotImplementedError

    def handle_ratelimited(self):
        """Build the response for a rate-limited client; must be implemented by subclasses."""
        raise NotImplementedError
|
2023-09-12 01:04:11 -06:00
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-12 16:40:09 -06:00
|
|
|
def get_backend():
    """Instantiate the inference backend selected by ``opts.mode``.

    Returns:
        An ``OobaboogaBackend`` when ``opts.mode == 'oobabooga'``, or a
        ``VLLMBackend`` when ``opts.mode == 'vllm'``.

    Raises:
        ValueError: if ``opts.mode`` names no known backend. ValueError is a
            subclass of Exception, so handlers catching ``Exception`` still
            work.
    """
    if opts.mode == 'oobabooga':
        return OobaboogaBackend()
    elif opts.mode == 'vllm':
        return VLLMBackend()
    else:
        raise ValueError(f'Unknown backend mode: {opts.mode!r}')
|
2023-09-12 01:04:11 -06:00
|
|
|
|
2023-08-30 18:53:26 -06:00
|
|
|
|
2023-09-12 16:40:09 -06:00
|
|
|
def delete_dict_key(d: dict, k: Union[str, list]):
    """Remove a key, or a list of keys, from ``d`` in place.

    Keys that are not present are silently ignored.

    Args:
        d: Dictionary to modify. It is mutated and also returned.
        k: A single string key, or a list of keys to remove.

    Returns:
        The same dict object ``d``.

    Raises:
        ValueError: if ``k`` is neither a ``str`` nor a ``list``. It is raised
            before any key is removed.
    """
    if isinstance(k, str):
        keys = (k,)
    elif isinstance(k, list):
        keys = k
    else:
        raise ValueError(f'k must be a str or list, got {type(k).__name__}')
    for key in keys:
        # pop with a default is a no-op for missing keys (single lookup).
        d.pop(key, None)
    return d
|