diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index 834c844..66659d4 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -1,9 +1,9 @@
-import json
 import pickle
 import time
 from typing import Tuple
 from uuid import uuid4
 
+import ujson as json
 from redis import Redis
 
 from llm_server import opts
@@ -28,23 +28,22 @@ class RedisPriorityQueue:
         self.redis = RedisCustom(name, db=db)
 
     def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
+        # TODO: remove this when we're sure nothing strange is happening
        assert item is not None
         assert priority is not None
         assert selected_model is not None
 
         event = DataEvent()
+
         # Check if the IP is already in the dictionary and if it has reached the limit
-        ip_count = self.redis.hget('queued_ip_count', item[1])
-        if ip_count:
-            ip_count = int(ip_count)
+        ip_count = self.get_ip_request_count(item[1])
         _, simultaneous_ip = get_token_ratelimit(item[2])
         if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
-            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
+            print(f'Rejecting request from {item[1]} - {ip_count} requests queued.')
             return None  # reject the request
 
         timestamp = time.time()
         self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
-        self.increment_ip_count(item[1], 'queued_ip_count')
         return event
 
     def get(self):
@@ -52,34 +51,20 @@ class RedisPriorityQueue:
             data = self.redis.zpopmin('queue')
             if data:
                 item = json.loads(data[0][0])
-                client_ip = item[0][1]
-                self.decrement_ip_count(client_ip, 'queued_ip_count')
                 return item
             time.sleep(0.1)  # wait for something to be added to the queue
 
-    # def print_all_items(self):
-    #     items = self.redis.zrange('queue', 0, -1)
-    #     to_print = []
-    #     for item in items:
-    #         to_print.append(item.decode('utf-8'))
-    #     print(f'ITEMS {self.name} -->', to_print)
-
-    def increment_ip_count(self, client_ip: str, redis_key):
-        self.redis.hincrby(redis_key, client_ip, 1)
-
-    def decrement_ip_count(self, client_ip: str, redis_key):
-        new_count = self.redis.hincrby(redis_key, client_ip, -1)
-        if new_count <= 0:
-            self.redis.hdel(redis_key, client_ip)
-
     def __len__(self):
         return self.redis.zcard('queue')
 
-    def get_queued_ip_count(self, client_ip: str):
-        q = self.redis.hget('queued_ip_count', client_ip)
-        if not q:
-            return 0
-        return 0
+    def get_ip_request_count(self, client_ip: str):
+        items = self.redis.zrange('queue', 0, -1)
+        count = 0
+        for item in items:
+            item_data = json.loads(item)
+            if item_data[0][1] == client_ip:
+                count += 1
+        return count
 
     def flush(self):
         self.redis.flush()
@@ -94,10 +79,7 @@ class RedisPriorityQueue:
             timestamp = item_data[-2]
             if now - timestamp > opts.backend_generate_request_timeout:
                 self.redis.zrem('queue', 0, item)
-                data = json.loads(item.decode('utf-8'))
-                event_id = data[1]
-                client_ip = data[0][1]
-                self.decrement_ip_count(client_ip, 'queued_ip_count')
+                event_id = item_data[1]
                 event = DataEvent(event_id)
                 event.set((False, None, 'closed'))
                 print('Removed timed-out item from queue:', event_id)
@@ -114,7 +96,6 @@ class DataEvent:
         self.redis.publish(self.event_id, pickle.dumps(data))
 
     def wait(self):
-        # TODO: implement timeout
         for item in self.pubsub.listen():
             if item['type'] == 'message':
                 return pickle.loads(item['data'])
@@ -157,7 +138,7 @@ class PriorityQueue:
         count = 0
         for backend_url in self.get_backends():
             queue = RedisPriorityQueue(backend_url)
-            count += queue.get_queued_ip_count(client_ip)
+            count += queue.get_ip_request_count(client_ip)
         return count
 
     def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index fe0d129..c9c421e 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -28,4 +28,4 @@ def console_printer():
             # Active Workers and Processing should read the same. If not, that's an issue.
             logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
 
-        time.sleep(10)
+        time.sleep(2)
diff --git a/server.py b/server.py
index e33d55a..43aa9d2 100644
--- a/server.py
+++ b/server.py
@@ -30,7 +30,7 @@ from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.sock import init_wssocket
 
-# TODO: queue item timeout
+# TODO: separate queue item timeout for websockets (make longer, like 5 minutes)
 # TODO: return an `error: True`, error code, and error message rather than just a formatted message
 # TODO: what happens when all backends are offline? What about the "online" key in the stats page?
 # TODO: redis SCAN vs KEYS??
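
Reviewer note: below is a minimal standalone sketch of the per-IP counting approach this patch switches to. Instead of maintaining a separate `queued_ip_count` hash (which could drift out of sync when timed-out items are pruned), the `queue` sorted set itself is scanned and members whose embedded client IP matches are tallied. The Redis connection settings and the exact `((request, client_ip, token, ...), event_id, model, timestamp, do_stream)` item layout are assumptions made for illustration only, not part of this patch.

# Sketch only -- assumes redis-py and the tuple layout written by RedisPriorityQueue.put() above.
import json  # the patch uses ujson; stdlib json behaves the same for this sketch

from redis import Redis

r = Redis(host='localhost', port=6379, db=0)  # assumed connection settings


def get_ip_request_count(client_ip: str) -> int:
    """Count how many queued entries belong to one client IP by scanning the sorted set."""
    count = 0
    for raw in r.zrange('queue', 0, -1):  # every member of the 'queue' sorted set
        item_data = json.loads(raw)
        # item_data[0] is the original `item` tuple; index 1 holds the client IP.
        if item_data[0][1] == client_ip:
            count += 1
    return count

The scan is linear in queue length, so each rate-limit check costs one extra ZRANGE, but the count always reflects what is actually queued rather than a counter that must be incremented and decremented in lockstep.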