fix ratelimiting
This commit is contained in:
parent d1c4e68f8b
commit b0089859d7
@@ -188,3 +188,21 @@ def increment_token_uses(token):
         cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
     finally:
         cursor.close()
+
+
+def get_token_ratelimit(token):
+    priority = 9990
+    simultaneous_ip = opts.simultaneous_requests_per_ip
+    if token:
+        cursor = database.cursor()
+        try:
+            cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
+            result = cursor.fetchone()
+            if result:
+                priority, simultaneous_ip = result
+                if simultaneous_ip is None:
+                    # No ratelimit for this token if null
+                    simultaneous_ip = 999999999
+        finally:
+            cursor.close()
+    return priority, simultaneous_ip
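The resolution above has three outcomes: no row (or no token) falls back to the defaults, a row with a numeric simultaneous_ip uses that value, and a NULL simultaneous_ip disables the per-IP limit. A dependency-free sketch of that branching, where resolve_ratelimit and the default limit of 3 are hypothetical stand-ins for the MySQL-backed function and opts.simultaneous_requests_per_ip:

def resolve_ratelimit(row, default_priority=9990, default_ip_limit=3):
    # `row` plays the role of cursor.fetchone() on token_auth.
    priority, simultaneous_ip = default_priority, default_ip_limit
    if row:
        priority, simultaneous_ip = row
        if simultaneous_ip is None:
            simultaneous_ip = 999999999  # NULL column -> effectively unlimited
    return priority, simultaneous_ip

assert resolve_ratelimit(None) == (9990, 3)                # no/unknown token
assert resolve_ratelimit((100, 5)) == (100, 5)             # token-specific limit
assert resolve_ratelimit((100, None)) == (100, 999999999)  # unlimited token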
@@ -25,6 +25,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
             return j['length']
         except Exception as e:
             print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
+            raise Exception
         return len(tokenizer.encode(chunk)) + 10
 
     # Use a ThreadPoolExecutor to send all chunks to the server at once
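One consequence of the added raise: a tokenization failure now propagates to the caller instead of being swallowed, which also makes the trailing tokenizer fallback unreachable on the error path. A minimal sketch of that control flow (function and values hypothetical):

def tokenize_sketch(ok: bool) -> int:
    try:
        if not ok:
            raise ValueError('backend failed')
        return 42                    # success path, like `return j['length']`
    except Exception as e:
        print(f'Failed to tokenize - {e.__class__.__name__}: {e}')
        raise                        # like the added `raise Exception`
    return -1                        # tokenizer fallback: no longer reachable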
@@ -5,8 +5,8 @@ from uuid import uuid4
 
 from redis import Redis
 
-from llm_server import opts
 from llm_server.custom_redis import RedisCustom, redis
+from llm_server.database.database import get_token_ratelimit
 
 
 def increment_ip_count(client_ip: str, redis_key):
@@ -32,7 +32,8 @@ class RedisPriorityQueue:
         ip_count = self.redis.hget('queued_ip_count', item[1])
         if ip_count:
             ip_count = int(ip_count)
-        if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
+        _, simultaneous_ip = get_token_ratelimit(item[2])
+        if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
             print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
             return None  # reject the request
 
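This check (likely in the queue module, given the RedisPriorityQueue context) now pulls the limit per token rather than from the global opts.simultaneous_requests_per_ip, which is why the opts import could be dropped above. It assumes each queue item carries the client IP at index 1 and the token at index 2; that layout is inferred from the indexing here, not shown elsewhere in the diff:

# Assumed (hypothetical) queue-item layout behind item[1]/item[2]:
item = ({'prompt': '...'}, '203.0.113.7', 'example-token', {})
client_ip, token = item[1], item[2]
assert (client_ip, token) == ('203.0.113.7', 'example-token')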
@@ -8,8 +8,7 @@ from llm_server import opts
 from llm_server.cluster.backend import get_a_cluster_backend
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
-from llm_server.database.conn import database
-from llm_server.database.database import log_prompt
+from llm_server.database.database import get_token_ratelimit, log_prompt
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
@@ -17,8 +16,6 @@ from llm_server.routes.auth import parse_token
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
-DEFAULT_PRIORITY = 9999
-
 
 class RequestHandler:
     def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@@ -36,7 +33,7 @@ class RequestHandler:
         self.start_time = time.time()
         self.client_ip = self.get_client_ip()
         self.token = self.get_auth_token()
-        self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
+        self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
         self.backend_url = get_a_cluster_backend(selected_model)
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
 
@@ -58,6 +55,8 @@ class RequestHandler:
         return parse_token(self.request.headers['Authorization'])
 
     def get_client_ip(self):
+        if self.request.headers.get('Llm-Connecting-Ip'):
+            return self.request.headers['Llm-Connecting-Ip']
         if self.request.headers.get('X-Connecting-IP'):
             return self.request.headers.get('X-Connecting-IP')
         elif self.request.headers.get('Cf-Connecting-Ip'):
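The new Llm-Connecting-Ip header is consulted before X-Connecting-IP and Cf-Connecting-Ip, so an upstream proxy that sets it now wins. A quick check of that ordering using Flask's test request context (app and addresses are illustrative):

from flask import Flask, request

app = Flask(__name__)
hdrs = {'Llm-Connecting-Ip': '198.51.100.1', 'X-Connecting-IP': '192.0.2.9'}
with app.test_request_context(headers=hdrs):
    # Mirrors the if/elif ordering above: the Llm-* header takes precedence.
    ip = request.headers.get('Llm-Connecting-Ip') or request.headers.get('X-Connecting-IP')
    assert ip == '198.51.100.1'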
@@ -67,23 +66,6 @@ class RequestHandler:
         else:
             return self.request.remote_addr
 
-    def get_token_ratelimit(self):
-        priority = DEFAULT_PRIORITY
-        simultaneous_ip = opts.simultaneous_requests_per_ip
-        if self.token:
-            cursor = database.cursor()
-            try:
-                cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
-                result = cursor.fetchone()
-                if result:
-                    priority, simultaneous_ip = result
-                    if simultaneous_ip is None:
-                        # No ratelimit for this token if null
-                        simultaneous_ip = 999999999
-            finally:
-                cursor.close()
-        return priority, simultaneous_ip
-
     def get_parameters(self):
         if self.request_json_body.get('max_tokens'):
             self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
@@ -210,17 +192,21 @@ class RequestHandler:
         return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 
     def is_client_ratelimited(self) -> bool:
+        if self.token_priority == 0:
+            return False
+
         queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
         x = redis.hget('processing_ips', self.client_ip)
         if x:
             processing_ip = int(x)
         else:
             processing_ip = 0
-        if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
-            return False
-        else:
+        if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
             print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
             return True
+        else:
+            return False
+
 
     def handle_request(self) -> Tuple[flask.Response, int]:
         # Must include this in your child.
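Net effect of the rewrite: priority-0 tokens now bypass the limit entirely with an early return (avoiding the Redis and queue lookups altogether), and everyone else is capped on queued plus in-flight requests per IP. A dependency-free sketch of the resulting decision (helper name and values hypothetical):

def is_ratelimited(token_priority: int, queued: int, processing: int, limit: int) -> bool:
    if token_priority == 0:
        return False                 # unlimited token: never ratelimited
    return queued + processing >= limit

assert is_ratelimited(0, 50, 50, 3) is False    # priority 0 bypasses the cap
assert is_ratelimited(9990, 2, 1, 3) is True    # 3 >= 3 -> reject
assert is_ratelimited(9990, 1, 1, 3) is False   # under the cap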