99 lines
2.9 KiB
Python
99 lines
2.9 KiB
Python
import json
|
|
import pickle
|
|
import time
|
|
from uuid import uuid4
|
|
|
|
from redis import Redis
|
|
|
|
from llm_server import opts
|
|
from llm_server.routes.cache import redis
|
|
|
|
|
|
def increment_ip_count(client_ip: str, redis_key):
    """Add one to *client_ip*'s counter in the Redis hash *redis_key*.

    Goes through the shared ``redis`` client imported from
    ``llm_server.routes.cache``.
    """
    field, delta = client_ip, 1
    redis.hincrby(redis_key, field, delta)
|
|
|
|
|
|
def decrement_ip_count(client_ip: str, redis_key):
    """Subtract one from *client_ip*'s counter in *redis_key*; drop the field at zero.

    NOTE(review): HINCRBY and HDEL are two separate commands, so an increment
    that lands between them is erased together with the field.
    """
    remaining = redis.hincrby(redis_key, client_ip, -1)
    if remaining > 0:
        return
    # Remove the field instead of leaving a zero/negative count behind.
    redis.hdel(redis_key, client_ip)
|
|
|
|
|
|
class RedisPriorityQueue:
    """Priority queue of requests backed by a Redis sorted set (db 15).

    Entries are stored in the ``queue`` sorted set as JSON-encoded
    ``(item, event_id)`` pairs scored by ``-priority``, so ``ZPOPMIN`` yields
    the highest priority first. A per-client counter lives in the
    ``queued_ip_count`` hash; a client may not have more than
    ``opts.simultaneous_requests_per_ip`` queued requests unless the request
    has priority 0.

    NOTE(review): members with equal scores are ordered lexicographically by
    Redis, not by insertion time, so same-priority requests are not FIFO.
    """

    def __init__(self):
        self.redis = Redis(host='localhost', port=6379, db=15)
        # NOTE(review): nothing in this class reads from the 'events'
        # subscription; confirm whether callers use self.pubsub before removing.
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority):
        """Enqueue *item* with *priority* and return a DataEvent for its result.

        ``item[1]`` is used as the client IP for per-client limiting.

        Returns:
            The ``DataEvent`` whose ``event_id`` is stored alongside the item,
            or ``None`` (nothing enqueued) when the client already has at
            least ``opts.simultaneous_requests_per_ip`` queued requests and
            ``priority`` is non-zero.
        """
        client_ip = item[1]

        # Reserve the slot first. HINCRBY is atomic, so concurrent callers
        # cannot both observe a count below the limit, and a get() popping this
        # item can never decrement before the increment lands (incrementing
        # after ZADD could leave a phantom count of 1 behind).
        previous = self.redis.hincrby('queued_ip_count', client_ip, 1) - 1
        if priority != 0 and previous and previous >= opts.simultaneous_requests_per_ip:
            self.decrement_ip_count(client_ip, 'queued_ip_count')
            print(f'Rejecting request from {client_ip} - {previous} requests in progress.')
            return None  # reject the request

        try:
            # Created only after the limit check so a rejected request does not
            # open (and leak) a pub/sub subscription.
            event = DataEvent()
            self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
        except Exception:
            # Hand the reserved slot back so the client's count does not drift.
            self.decrement_ip_count(client_ip, 'queued_ip_count')
            raise
        return event

    def get(self):
        """Block until an entry is available, pop it and return it.

        Returns the decoded ``[item, event_id]`` pair (JSON turns the stored
        tuples into lists) with the lowest score, i.e. the highest priority,
        and releases that client's slot in ``queued_ip_count``. Polls the
        queue every 0.1 s while it is empty.
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                item = json.loads(data[0][0])  # data == [(member, score)]
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.1)  # wait for something to be added to the queue

    def increment_ip_count(self, client_ip: str, redis_key):
        """Add one to *client_ip*'s counter in the hash *redis_key*."""
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        """Subtract one from *client_ip*'s counter; delete the field at or below zero.

        NOTE(review): HINCRBY and HDEL are separate commands, so an increment
        arriving between them is erased along with the field.
        """
        new_count = self.redis.hincrby(redis_key, client_ip, -1)
        if new_count <= 0:
            self.redis.hdel(redis_key, client_ip)

    def __len__(self):
        """Number of entries currently in the ``queue`` sorted set."""
        return self.redis.zcard('queue')

    def get_queued_ip_count(self, client_ip: str) -> int:
        """Return how many requests *client_ip* currently has queued (0 if none)."""
        q = self.redis.hget('queued_ip_count', client_ip)
        if not q:
            return 0
        # Previously this returned 0 unconditionally, hiding real counts.
        return int(q)
|
|
|
|
|
|
class DataEvent:
    """One-shot result channel keyed by ``event_id`` (Redis db 14).

    Each instance subscribes to a pub/sub channel named after ``event_id``.
    ``set()`` publishes a pickled value on that channel and ``wait()`` blocks
    until such a value arrives and returns it unpickled.
    """

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        # Subscribe up front: Redis pub/sub only delivers a message to clients
        # that are already subscribed when it is published.
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish *data* (pickled) on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a value is published on this channel and return it.

        The pub/sub subscription is closed once a message arrives (or if
        listening/unpickling raises), so the connection is not held open after
        the one-shot result. Returns ``None`` if the listener stops without
        delivering a message.
        """
        try:
            for message in self.pubsub.listen():
                # Skip subscribe confirmations and other control messages.
                if message['type'] == 'message':
                    # SECURITY: pickle.loads can execute arbitrary code for a
                    # crafted payload; anything able to publish to this Redis
                    # instance must be trusted.
                    return pickle.loads(message['data'])
        finally:
            self.pubsub.close()
|
|
|
|
|
|
# Module-level queue singleton, created at import time. RedisPriorityQueue's
# constructor subscribes to a Redis pub/sub channel, so importing this module
# opens a Redis connection (localhost:6379, db 15).
priority_queue = RedisPriorityQueue()
|
|
|
|
|
|
def incr_active_workers():
    """Increment the shared ``active_gen_workers`` counter in Redis."""
    counter_key = 'active_gen_workers'
    redis.incr(counter_key)
|
|
|
|
|
|
def decr_active_workers():
    """Decrement ``active_gen_workers``, clamping the stored value at zero.

    NOTE(review): ``redis.get(key, int, 0)`` is the project wrapper's getter,
    presumably (key, dtype, default) — confirm. Decrement, read and reset are
    separate commands, so a concurrent increment between them can be lost.
    """
    counter_key = 'active_gen_workers'
    redis.decr(counter_key)
    if redis.get(counter_key, int, 0) >= 0:
        return
    # Went negative (more decrements than increments): reset to zero.
    redis.set(counter_key, 0)
|