local-llm-server/llm_server/routes/queue.py

import json
import pickle
import time
from uuid import uuid4

from redis import Redis

from llm_server import opts
from llm_server.custom_redis import RedisCustom, redis


def increment_ip_count(client_ip: str, redis_key):
    redis.hincrby(redis_key, client_ip, 1)


def decrement_ip_count(client_ip: str, redis_key):
    new_count = redis.hincrby(redis_key, client_ip, -1)
    if new_count <= 0:
        redis.hdel(redis_key, client_ip)
class RedisPriorityQueue:
    def __init__(self, name: str = 'priority_queue', db: int = 12):
        self.redis = RedisCustom(name, db=db)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority, selected_model):
        event = DataEvent()

        # Check if the IP is already in the dictionary and if it has reached the limit
        ip_count = self.redis.hget('queued_ip_count', item[1])
        if ip_count is not None:
            ip_count = int(ip_count)
        if ip_count and ip_count >= opts.simultaneous_requests_per_ip and priority != 0:
            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
            return None  # reject the request

        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
        self.increment_ip_count(item[1], 'queued_ip_count')
        return event
    def get(self):
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                item = json.loads(data[0][0])
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.1)  # wait for something to be added to the queue

    def increment_ip_count(self, client_ip: str, redis_key):
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        new_count = self.redis.hincrby(redis_key, client_ip, -1)
        if new_count <= 0:
            self.redis.hdel(redis_key, client_ip)

    def __len__(self):
        return self.redis.zcard('queue')

    def len(self, model_name):
        count = 0
        for key in self.redis.zrange('queue', 0, -1):
            item = json.loads(key)
            if item[2] == model_name:
                count += 1
        return count
    def get_queued_ip_count(self, client_ip: str):
        q = self.redis.hget('queued_ip_count', client_ip)
        if not q:
            return 0
        return int(q)
    def flush(self):
        self.redis.flush()


class DataEvent:
    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                return pickle.loads(item['data'])
priority_queue = RedisPriorityQueue()


def incr_active_workers(selected_model: str, backend_url: str):
    redis.incr(f'active_gen_workers:{selected_model}')
    redis.incr(f'active_gen_workers:{backend_url}')


def decr_active_workers(selected_model: str, backend_url: str):
    redis.decr(f'active_gen_workers:{selected_model}')
    if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0:
        redis.set(f'active_gen_workers:{selected_model}', 0)

    redis.decr(f'active_gen_workers:{backend_url}')
    if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0:
        redis.set(f'active_gen_workers:{backend_url}', 0)
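

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It shows
# how a request handler and a generation worker could tie the pieces above
# together: the handler enqueues a request and blocks on the returned
# DataEvent, while a worker pops the request, runs generation, and publishes
# the result back on the same event ID. The names `handle_request`,
# `worker_loop`, and `run_generation`, and the (request_json, client_ip) item
# layout are assumptions for this example; only positions 0 and 1 of the item
# tuple are actually read by the queue itself.
# ---------------------------------------------------------------------------
def handle_request(request_json: dict, client_ip: str, selected_model: str):
    # put() returns None when this client IP already has too many queued requests.
    event = priority_queue.put((request_json, client_ip), priority=1, selected_model=selected_model)
    if event is None:
        return {'error': 'too many simultaneous requests'}, 429
    return event.wait()  # blocks until a worker publishes the result


def worker_loop(backend_url: str):
    while True:
        item, event_id, selected_model = priority_queue.get()
        request_json, client_ip = item
        incr_active_workers(selected_model, backend_url)
        try:
            result = run_generation(request_json, selected_model, backend_url)  # hypothetical backend call
            DataEvent(event_id).set(result)  # wake the handler waiting on this event ID
        finally:
            decr_active_workers(selected_model, backend_url)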