fix ratelimiting

This commit is contained in:
Cyberes 2023-10-02 02:05:15 -06:00
parent d1c4e68f8b
commit b0089859d7
4 changed files with 33 additions and 27 deletions

View File

@ -188,3 +188,21 @@ def increment_token_uses(token):
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,)) cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
finally: finally:
cursor.close() cursor.close()
def get_token_ratelimit(token):
priority = 9990
simultaneous_ip = opts.simultaneous_requests_per_ip
if token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip

View File

@ -25,6 +25,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
return j['length'] return j['length']
except Exception as e: except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
raise Exception
return len(tokenizer.encode(chunk)) + 10 return len(tokenizer.encode(chunk)) + 10
# Use a ThreadPoolExecutor to send all chunks to the server at once # Use a ThreadPoolExecutor to send all chunks to the server at once

View File

@ -5,8 +5,8 @@ from uuid import uuid4
from redis import Redis from redis import Redis
from llm_server import opts
from llm_server.custom_redis import RedisCustom, redis from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
def increment_ip_count(client_ip: str, redis_key): def increment_ip_count(client_ip: str, redis_key):
@ -32,7 +32,8 @@ class RedisPriorityQueue:
ip_count = self.redis.hget('queued_ip_count', item[1]) ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count: if ip_count:
ip_count = int(ip_count) ip_count = int(ip_count)
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0: _, simultaneous_ip = get_token_ratelimit(item[2])
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request return None # reject the request

View File

@ -8,8 +8,7 @@ from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend from llm_server.cluster.backend import get_a_cluster_backend
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from llm_server.database.conn import database from llm_server.database.database import get_token_ratelimit, log_prompt
from llm_server.database.database import log_prompt
from llm_server.helpers import auto_set_base_client_api from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend
@ -17,8 +16,6 @@ from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue from llm_server.routes.queue import priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler: class RequestHandler:
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None): def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@ -36,7 +33,7 @@ class RequestHandler:
self.start_time = time.time() self.start_time = time.time()
self.client_ip = self.get_client_ip() self.client_ip = self.get_client_ip()
self.token = self.get_auth_token() self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit() self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.backend_url = get_a_cluster_backend(selected_model) self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url) self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
@ -58,6 +55,8 @@ class RequestHandler:
return parse_token(self.request.headers['Authorization']) return parse_token(self.request.headers['Authorization'])
def get_client_ip(self): def get_client_ip(self):
if self.request.headers.get('Llm-Connecting-Ip'):
return self.request.headers['Llm-Connecting-Ip']
if self.request.headers.get('X-Connecting-IP'): if self.request.headers.get('X-Connecting-IP'):
return self.request.headers.get('X-Connecting-IP') return self.request.headers.get('X-Connecting-IP')
elif self.request.headers.get('Cf-Connecting-Ip'): elif self.request.headers.get('Cf-Connecting-Ip'):
@ -67,23 +66,6 @@ class RequestHandler:
else: else:
return self.request.remote_addr return self.request.remote_addr
def get_token_ratelimit(self):
priority = DEFAULT_PRIORITY
simultaneous_ip = opts.simultaneous_requests_per_ip
if self.token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip
def get_parameters(self): def get_parameters(self):
if self.request_json_body.get('max_tokens'): if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens') self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
@ -210,17 +192,21 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool: def is_client_ratelimited(self) -> bool:
if self.token_priority == 0:
return False
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip)) queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
x = redis.hget('processing_ips', self.client_ip) x = redis.hget('processing_ips', self.client_ip)
if x: if x:
processing_ip = int(x) processing_ip = int(x)
else: else:
processing_ip = 0 processing_ip = 0
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
return False if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
else:
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.') print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
return True return True
else:
return False
def handle_request(self) -> Tuple[flask.Response, int]: def handle_request(self) -> Tuple[flask.Response, int]:
# Must include this in your child. # Must include this in your child.