diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index 1dc2145..dec5e98 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -188,3 +188,21 @@ def increment_token_uses(token):
         cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
     finally:
         cursor.close()
+
+
+def get_token_ratelimit(token):
+    priority = 9990
+    simultaneous_ip = opts.simultaneous_requests_per_ip
+    if token:
+        cursor = database.cursor()
+        try:
+            cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
+            result = cursor.fetchone()
+            if result:
+                priority, simultaneous_ip = result
+                if simultaneous_ip is None:
+                    # No ratelimit for this token if null
+                    simultaneous_ip = 999999999
+        finally:
+            cursor.close()
+    return priority, simultaneous_ip
diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
index 006842e..bd44ad8 100644
--- a/llm_server/llm/vllm/tokenize.py
+++ b/llm_server/llm/vllm/tokenize.py
@@ -25,6 +25,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
             return j['length']
         except Exception as e:
             print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
+            raise Exception
             return len(tokenizer.encode(chunk)) + 10
 
     # Use a ThreadPoolExecutor to send all chunks to the server at once
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index f058298..5d2c6b3 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -5,8 +5,8 @@ from uuid import uuid4
 
 from redis import Redis
 
-from llm_server import opts
 from llm_server.custom_redis import RedisCustom, redis
+from llm_server.database.database import get_token_ratelimit
 
 
 def increment_ip_count(client_ip: str, redis_key):
@@ -32,7 +32,8 @@ class RedisPriorityQueue:
         ip_count = self.redis.hget('queued_ip_count', item[1])
         if ip_count:
             ip_count = int(ip_count)
-        if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
+        _, simultaneous_ip = get_token_ratelimit(item[2])
+        if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
             print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
             return None  # reject the request
 
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index f93547b..7c425dc 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -8,8 +8,7 @@ from llm_server import opts
 from llm_server.cluster.backend import get_a_cluster_backend
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
-from llm_server.database.conn import database
-from llm_server.database.database import log_prompt
+from llm_server.database.database import get_token_ratelimit, log_prompt
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
@@ -17,8 +16,6 @@ from llm_server.routes.auth import parse_token
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
-DEFAULT_PRIORITY = 9999
-
 
 class RequestHandler:
     def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
@@ -36,7 +33,7 @@ class RequestHandler:
         self.start_time = time.time()
         self.client_ip = self.get_client_ip()
         self.token = self.get_auth_token()
-        self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
+        self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
         self.backend_url = get_a_cluster_backend(selected_model)
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
 
@@ -58,6 +55,8 @@ class RequestHandler:
         return parse_token(self.request.headers['Authorization'])
 
     def get_client_ip(self):
+        if self.request.headers.get('Llm-Connecting-Ip'):
+            return self.request.headers['Llm-Connecting-Ip']
         if self.request.headers.get('X-Connecting-IP'):
             return self.request.headers.get('X-Connecting-IP')
         elif self.request.headers.get('Cf-Connecting-Ip'):
@@ -67,23 +66,6 @@ class RequestHandler:
         else:
             return self.request.remote_addr
 
-    def get_token_ratelimit(self):
-        priority = DEFAULT_PRIORITY
-        simultaneous_ip = opts.simultaneous_requests_per_ip
-        if self.token:
-            cursor = database.cursor()
-            try:
-                cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
-                result = cursor.fetchone()
-                if result:
-                    priority, simultaneous_ip = result
-                    if simultaneous_ip is None:
-                        # No ratelimit for this token if null
-                        simultaneous_ip = 999999999
-            finally:
-                cursor.close()
-        return priority, simultaneous_ip
-
     def get_parameters(self):
         if self.request_json_body.get('max_tokens'):
             self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
@@ -210,17 +192,21 @@ class RequestHandler:
         return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 
     def is_client_ratelimited(self) -> bool:
+        if self.token_priority == 0:
+            return False
+
         queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
         x = redis.hget('processing_ips', self.client_ip)
         if x:
             processing_ip = int(x)
         else:
             processing_ip = 0
-        if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
-            return False
-        else:
+
+        if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
             print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} already queued/processing.')
             return True
+        else:
+            return False
 
     def handle_request(self) -> Tuple[flask.Response, int]:
         # Must include this in your child.
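
Note: below is a minimal, self-contained sketch (illustrative only, not part of the patch) of the ratelimit decision that the reworked is_client_ratelimited() and the queue's check now share: a token whose priority is 0 is never ratelimited, and every other client is capped at its per-token simultaneous_ip, which get_token_ratelimit() sets to 999999999 (effectively unlimited) when the token_auth row stores NULL. The helper name and the numbers here are hypothetical.

def is_ratelimited(token_priority: int, token_simultaneous_ip: int,
                   queued: int, processing: int) -> bool:
    # Priority-0 tokens bypass the ratelimit entirely.
    if token_priority == 0:
        return False
    # Everyone else is rejected once queued + in-flight requests reach the cap.
    return queued + processing >= token_simultaneous_ip

assert is_ratelimited(0, 1, 50, 50) is False            # priority 0: never limited
assert is_ratelimited(9990, 3, 1, 1) is False           # 2 < 3: allowed
assert is_ratelimited(9990, 3, 2, 1) is True            # 3 >= 3: rejected
assert is_ratelimited(9990, 999999999, 40, 2) is False  # NULL cap: effectively unlimited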