refer to queue for tracking IP count rather than seperate value

This commit is contained in:
Cyberes 2023-10-18 09:03:10 -06:00
parent be03569165
commit 92e4ecd8a1
3 changed files with 17 additions and 36 deletions

View File

@ -1,9 +1,9 @@
import json
import pickle import pickle
import time import time
from typing import Tuple from typing import Tuple
from uuid import uuid4 from uuid import uuid4
import ujson as json
from redis import Redis from redis import Redis
from llm_server import opts from llm_server import opts
@ -28,23 +28,22 @@ class RedisPriorityQueue:
self.redis = RedisCustom(name, db=db) self.redis = RedisCustom(name, db=db)
def put(self, item, priority: int, selected_model: str, do_stream: bool = False): def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
# TODO: remove this when we're sure nothing strange is happening
assert item is not None assert item is not None
assert priority is not None assert priority is not None
assert selected_model is not None assert selected_model is not None
event = DataEvent() event = DataEvent()
# Check if the IP is already in the dictionary and if it has reached the limit # Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1]) ip_count = self.get_ip_request_count(item[1])
if ip_count:
ip_count = int(ip_count)
_, simultaneous_ip = get_token_ratelimit(item[2]) _, simultaneous_ip = get_token_ratelimit(item[2])
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0: if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
return None # reject the request return None # reject the request
timestamp = time.time() timestamp = time.time()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority}) self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
return event return event
def get(self): def get(self):
@ -52,34 +51,20 @@ class RedisPriorityQueue:
data = self.redis.zpopmin('queue') data = self.redis.zpopmin('queue')
if data: if data:
item = json.loads(data[0][0]) item = json.loads(data[0][0])
client_ip = item[0][1]
self.decrement_ip_count(client_ip, 'queued_ip_count')
return item return item
time.sleep(0.1) # wait for something to be added to the queue time.sleep(0.1) # wait for something to be added to the queue
# def print_all_items(self):
# items = self.redis.zrange('queue', 0, -1)
# to_print = []
# for item in items:
# to_print.append(item.decode('utf-8'))
# print(f'ITEMS {self.name} -->', to_print)
def increment_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(redis_key, client_ip, 1)
def decrement_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, -1)
if new_count <= 0:
self.redis.hdel(redis_key, client_ip)
def __len__(self): def __len__(self):
return self.redis.zcard('queue') return self.redis.zcard('queue')
def get_queued_ip_count(self, client_ip: str): def get_ip_request_count(self, client_ip: str):
q = self.redis.hget('queued_ip_count', client_ip) items = self.redis.zrange('queue', 0, -1)
if not q: count = 0
return 0 for item in items:
return 0 item_data = json.loads(item)
if item_data[0][1] == client_ip:
count += 1
return count
def flush(self): def flush(self):
self.redis.flush() self.redis.flush()
@ -94,10 +79,7 @@ class RedisPriorityQueue:
timestamp = item_data[-2] timestamp = item_data[-2]
if now - timestamp > opts.backend_generate_request_timeout: if now - timestamp > opts.backend_generate_request_timeout:
self.redis.zrem('queue', 0, item) self.redis.zrem('queue', 0, item)
data = json.loads(item.decode('utf-8')) event_id = item_data[1]
event_id = data[1]
client_ip = data[0][1]
self.decrement_ip_count(client_ip, 'queued_ip_count')
event = DataEvent(event_id) event = DataEvent(event_id)
event.set((False, None, 'closed')) event.set((False, None, 'closed'))
print('Removed timed-out item from queue:', event_id) print('Removed timed-out item from queue:', event_id)
@ -114,7 +96,6 @@ class DataEvent:
self.redis.publish(self.event_id, pickle.dumps(data)) self.redis.publish(self.event_id, pickle.dumps(data))
def wait(self): def wait(self):
# TODO: implement timeout
for item in self.pubsub.listen(): for item in self.pubsub.listen():
if item['type'] == 'message': if item['type'] == 'message':
return pickle.loads(item['data']) return pickle.loads(item['data'])
@ -157,7 +138,7 @@ class PriorityQueue:
count = 0 count = 0
for backend_url in self.get_backends(): for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url) queue = RedisPriorityQueue(backend_url)
count += queue.get_queued_ip_count(client_ip) count += queue.get_ip_request_count(client_ip)
return count return count
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False): def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):

View File

@ -28,4 +28,4 @@ def console_printer():
# Active Workers and Processing should read the same. If not, that's an issue. # Active Workers and Processing should read the same. If not, that's an issue.
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}') logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(10) time.sleep(2)

View File

@ -30,7 +30,7 @@ from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.sock import init_wssocket from llm_server.sock import init_wssocket
# TODO: queue item timeout # TODO: seperate queue item timeout for websockets (make longer, like 5 minutes)
# TODO: return an `error: True`, error code, and error message rather than just a formatted message # TODO: return an `error: True`, error code, and error message rather than just a formatted message
# TODO: what happens when all backends are offline? What about the "online" key in the stats page? # TODO: what happens when all backends are offline? What about the "online" key in the stats page?
# TODO: redis SCAN vs KEYS?? # TODO: redis SCAN vs KEYS??