refer to queue for tracking IP count rather than seperate value
This commit is contained in:
parent
be03569165
commit
92e4ecd8a1
|
@ -1,9 +1,9 @@
|
||||||
import json
|
|
||||||
import pickle
|
import pickle
|
||||||
import time
|
import time
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import ujson as json
|
||||||
from redis import Redis
|
from redis import Redis
|
||||||
|
|
||||||
from llm_server import opts
|
from llm_server import opts
|
||||||
|
@ -28,23 +28,22 @@ class RedisPriorityQueue:
|
||||||
self.redis = RedisCustom(name, db=db)
|
self.redis = RedisCustom(name, db=db)
|
||||||
|
|
||||||
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
|
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
|
||||||
|
# TODO: remove this when we're sure nothing strange is happening
|
||||||
assert item is not None
|
assert item is not None
|
||||||
assert priority is not None
|
assert priority is not None
|
||||||
assert selected_model is not None
|
assert selected_model is not None
|
||||||
|
|
||||||
event = DataEvent()
|
event = DataEvent()
|
||||||
|
|
||||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||||
ip_count = self.redis.hget('queued_ip_count', item[1])
|
ip_count = self.get_ip_request_count(item[1])
|
||||||
if ip_count:
|
|
||||||
ip_count = int(ip_count)
|
|
||||||
_, simultaneous_ip = get_token_ratelimit(item[2])
|
_, simultaneous_ip = get_token_ratelimit(item[2])
|
||||||
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
|
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
|
||||||
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
|
print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
|
||||||
return None # reject the request
|
return None # reject the request
|
||||||
|
|
||||||
timestamp = time.time()
|
timestamp = time.time()
|
||||||
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
|
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
|
||||||
self.increment_ip_count(item[1], 'queued_ip_count')
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
def get(self):
|
def get(self):
|
||||||
|
@ -52,34 +51,20 @@ class RedisPriorityQueue:
|
||||||
data = self.redis.zpopmin('queue')
|
data = self.redis.zpopmin('queue')
|
||||||
if data:
|
if data:
|
||||||
item = json.loads(data[0][0])
|
item = json.loads(data[0][0])
|
||||||
client_ip = item[0][1]
|
|
||||||
self.decrement_ip_count(client_ip, 'queued_ip_count')
|
|
||||||
return item
|
return item
|
||||||
time.sleep(0.1) # wait for something to be added to the queue
|
time.sleep(0.1) # wait for something to be added to the queue
|
||||||
|
|
||||||
# def print_all_items(self):
|
|
||||||
# items = self.redis.zrange('queue', 0, -1)
|
|
||||||
# to_print = []
|
|
||||||
# for item in items:
|
|
||||||
# to_print.append(item.decode('utf-8'))
|
|
||||||
# print(f'ITEMS {self.name} -->', to_print)
|
|
||||||
|
|
||||||
def increment_ip_count(self, client_ip: str, redis_key):
|
|
||||||
self.redis.hincrby(redis_key, client_ip, 1)
|
|
||||||
|
|
||||||
def decrement_ip_count(self, client_ip: str, redis_key):
|
|
||||||
new_count = self.redis.hincrby(redis_key, client_ip, -1)
|
|
||||||
if new_count <= 0:
|
|
||||||
self.redis.hdel(redis_key, client_ip)
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.redis.zcard('queue')
|
return self.redis.zcard('queue')
|
||||||
|
|
||||||
def get_queued_ip_count(self, client_ip: str):
|
def get_ip_request_count(self, client_ip: str):
|
||||||
q = self.redis.hget('queued_ip_count', client_ip)
|
items = self.redis.zrange('queue', 0, -1)
|
||||||
if not q:
|
count = 0
|
||||||
return 0
|
for item in items:
|
||||||
return 0
|
item_data = json.loads(item)
|
||||||
|
if item_data[0][1] == client_ip:
|
||||||
|
count += 1
|
||||||
|
return count
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
self.redis.flush()
|
self.redis.flush()
|
||||||
|
@ -94,10 +79,7 @@ class RedisPriorityQueue:
|
||||||
timestamp = item_data[-2]
|
timestamp = item_data[-2]
|
||||||
if now - timestamp > opts.backend_generate_request_timeout:
|
if now - timestamp > opts.backend_generate_request_timeout:
|
||||||
self.redis.zrem('queue', 0, item)
|
self.redis.zrem('queue', 0, item)
|
||||||
data = json.loads(item.decode('utf-8'))
|
event_id = item_data[1]
|
||||||
event_id = data[1]
|
|
||||||
client_ip = data[0][1]
|
|
||||||
self.decrement_ip_count(client_ip, 'queued_ip_count')
|
|
||||||
event = DataEvent(event_id)
|
event = DataEvent(event_id)
|
||||||
event.set((False, None, 'closed'))
|
event.set((False, None, 'closed'))
|
||||||
print('Removed timed-out item from queue:', event_id)
|
print('Removed timed-out item from queue:', event_id)
|
||||||
|
@ -114,7 +96,6 @@ class DataEvent:
|
||||||
self.redis.publish(self.event_id, pickle.dumps(data))
|
self.redis.publish(self.event_id, pickle.dumps(data))
|
||||||
|
|
||||||
def wait(self):
|
def wait(self):
|
||||||
# TODO: implement timeout
|
|
||||||
for item in self.pubsub.listen():
|
for item in self.pubsub.listen():
|
||||||
if item['type'] == 'message':
|
if item['type'] == 'message':
|
||||||
return pickle.loads(item['data'])
|
return pickle.loads(item['data'])
|
||||||
|
@ -157,7 +138,7 @@ class PriorityQueue:
|
||||||
count = 0
|
count = 0
|
||||||
for backend_url in self.get_backends():
|
for backend_url in self.get_backends():
|
||||||
queue = RedisPriorityQueue(backend_url)
|
queue = RedisPriorityQueue(backend_url)
|
||||||
count += queue.get_queued_ip_count(client_ip)
|
count += queue.get_ip_request_count(client_ip)
|
||||||
return count
|
return count
|
||||||
|
|
||||||
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
|
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
|
||||||
|
|
|
@ -28,4 +28,4 @@ def console_printer():
|
||||||
|
|
||||||
# Active Workers and Processing should read the same. If not, that's an issue.
|
# Active Workers and Processing should read the same. If not, that's an issue.
|
||||||
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
|
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
|
||||||
time.sleep(10)
|
time.sleep(2)
|
||||||
|
|
|
@ -30,7 +30,7 @@ from llm_server.routes.v1 import bp
|
||||||
from llm_server.routes.v1.generate_stats import generate_stats
|
from llm_server.routes.v1.generate_stats import generate_stats
|
||||||
from llm_server.sock import init_wssocket
|
from llm_server.sock import init_wssocket
|
||||||
|
|
||||||
# TODO: queue item timeout
|
# TODO: seperate queue item timeout for websockets (make longer, like 5 minutes)
|
||||||
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
|
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
|
||||||
# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
|
# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
|
||||||
# TODO: redis SCAN vs KEYS??
|
# TODO: redis SCAN vs KEYS??
|
||||||
|
|
Reference in New Issue