fix ratelimiting
This commit is contained in:
parent
d1c4e68f8b
commit
b0089859d7
|
@ -188,3 +188,21 @@ def increment_token_uses(token):
|
|||
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
def get_token_ratelimit(token):
|
||||
priority = 9990
|
||||
simultaneous_ip = opts.simultaneous_requests_per_ip
|
||||
if token:
|
||||
cursor = database.cursor()
|
||||
try:
|
||||
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
priority, simultaneous_ip = result
|
||||
if simultaneous_ip is None:
|
||||
# No ratelimit for this token if null
|
||||
simultaneous_ip = 999999999
|
||||
finally:
|
||||
cursor.close()
|
||||
return priority, simultaneous_ip
|
||||
|
|
|
@ -25,6 +25,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
|
|||
return j['length']
|
||||
except Exception as e:
|
||||
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
|
||||
raise Exception
|
||||
return len(tokenizer.encode(chunk)) + 10
|
||||
|
||||
# Use a ThreadPoolExecutor to send all chunks to the server at once
|
||||
|
|
|
@ -5,8 +5,8 @@ from uuid import uuid4
|
|||
|
||||
from redis import Redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
from llm_server.database.database import get_token_ratelimit
|
||||
|
||||
|
||||
def increment_ip_count(client_ip: str, redis_key):
|
||||
|
@ -32,7 +32,8 @@ class RedisPriorityQueue:
|
|||
ip_count = self.redis.hget('queued_ip_count', item[1])
|
||||
if ip_count:
|
||||
ip_count = int(ip_count)
|
||||
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
|
||||
_, simultaneous_ip = get_token_ratelimit(item[2])
|
||||
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
|
||||
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
|
||||
return None # reject the request
|
||||
|
||||
|
|
|
@ -8,8 +8,7 @@ from llm_server import opts
|
|||
from llm_server.cluster.backend import get_a_cluster_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.database.database import get_token_ratelimit, log_prompt
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
|
@ -17,8 +16,6 @@ from llm_server.routes.auth import parse_token
|
|||
from llm_server.routes.helpers.http import require_api_key, validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
DEFAULT_PRIORITY = 9999
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
|
||||
|
@ -36,7 +33,7 @@ class RequestHandler:
|
|||
self.start_time = time.time()
|
||||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.get_auth_token()
|
||||
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
|
||||
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
||||
|
@ -58,6 +55,8 @@ class RequestHandler:
|
|||
return parse_token(self.request.headers['Authorization'])
|
||||
|
||||
def get_client_ip(self):
|
||||
if self.request.headers.get('Llm-Connecting-Ip'):
|
||||
return self.request.headers['Llm-Connecting-Ip']
|
||||
if self.request.headers.get('X-Connecting-IP'):
|
||||
return self.request.headers.get('X-Connecting-IP')
|
||||
elif self.request.headers.get('Cf-Connecting-Ip'):
|
||||
|
@ -67,23 +66,6 @@ class RequestHandler:
|
|||
else:
|
||||
return self.request.remote_addr
|
||||
|
||||
def get_token_ratelimit(self):
|
||||
priority = DEFAULT_PRIORITY
|
||||
simultaneous_ip = opts.simultaneous_requests_per_ip
|
||||
if self.token:
|
||||
cursor = database.cursor()
|
||||
try:
|
||||
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
priority, simultaneous_ip = result
|
||||
if simultaneous_ip is None:
|
||||
# No ratelimit for this token if null
|
||||
simultaneous_ip = 999999999
|
||||
finally:
|
||||
cursor.close()
|
||||
return priority, simultaneous_ip
|
||||
|
||||
def get_parameters(self):
|
||||
if self.request_json_body.get('max_tokens'):
|
||||
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
|
||||
|
@ -210,17 +192,21 @@ class RequestHandler:
|
|||
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
||||
|
||||
def is_client_ratelimited(self) -> bool:
|
||||
if self.token_priority == 0:
|
||||
return False
|
||||
|
||||
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
|
||||
x = redis.hget('processing_ips', self.client_ip)
|
||||
if x:
|
||||
processing_ip = int(x)
|
||||
else:
|
||||
processing_ip = 0
|
||||
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
|
||||
return False
|
||||
else:
|
||||
|
||||
if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
|
||||
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||
# Must include this in your child.
|
||||
|
|
Reference in New Issue