local-llm-server/llm_server/routes/queue.py

127 lines
3.9 KiB
Python

import json
import pickle
import threading
import time
from uuid import uuid4
from redis import Redis
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
# Runs at import time: store an empty dict under 'processing_ips' so the
# per-IP table used by worker() (via increment_ip_count/decrement_ip_count)
# starts out empty.
redis.set_dict('processing_ips', {})
def increment_ip_count(client_ip: int, redis_key):
    """Add one to ``client_ip``'s entry in the dict stored under ``redis_key``.

    The dict is loaded with ``redis.get_dict``, updated in memory and written
    back with ``redis.set_dict``. An IP that has no entry yet starts from 0.

    :param client_ip: key of the entry to bump.
    :param redis_key: name under which the counter dict is stored.
    :return: the updated dict that was written back.
    """
    counts = redis.get_dict(redis_key)
    previous = counts.get(client_ip, 0)
    counts[client_ip] = previous + 1
    redis.set_dict(redis_key, counts)
    return counts
def decrement_ip_count(client_ip: int, redis_key):
    """Subtract one from ``client_ip``'s entry in the dict under ``redis_key``.

    An IP without an entry is left alone. When the entry reaches exactly 0 it
    is removed from the dict. The dict is written back with ``redis.set_dict``
    in every case.

    :param client_ip: key of the entry to lower.
    :param redis_key: name under which the counter dict is stored.
    :return: the dict that was written back.
    """
    counts = redis.get_dict(redis_key)
    if client_ip in counts:
        remaining = counts[client_ip] - 1
        if remaining == 0:
            # Nothing left for this IP: drop the entry instead of storing 0.
            del counts[client_ip]
        else:
            counts[client_ip] = remaining
    redis.set_dict(redis_key, counts)
    return counts
class RedisPriorityQueue:
    """Priority queue of requests kept in a Redis sorted set.

    Entries are stored in the sorted set ``queue`` of Redis DB 15 on
    localhost. The hash ``queued_ip_count`` holds, per client IP, how many
    entries that IP currently has in the queue.
    """

    def __init__(self):
        """Connect to Redis DB 15, delete every key in it and subscribe to 'events'."""
        self._index = 0
        self._lock = threading.Lock()
        self.redis = Redis(host='localhost', port=6379, db=15)

        # Clear the DB
        for key in self.redis.scan_iter('*'):
            self.redis.delete(key)

        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority):
        """Enqueue ``item`` and return the event on which its result will arrive.

        :param item: sequence whose element 1 is the client IP; it is stored
            JSON-encoded together with a running index and the event id.
        :param priority: numeric priority. Entries are scored with
            ``-priority`` and ``get()`` pops the lowest score, so a larger
            priority is served first. Priority 0 is exempt from the per-IP
            limit.
        :return: a ``DataEvent`` for the queued request, or ``None`` when the
            IP already has ``opts.simultaneous_requests_per_ip`` or more
            entries queued.
        """
        client_ip = item[1]
        # The lock serialises the limit check, the enqueue and the updates of
        # ``_index`` and the per-IP counter for threads of this process.
        with self._lock:
            # Check if the IP is already in the hash and has reached the limit
            ip_count = self.redis.hget('queued_ip_count', client_ip)
            if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
                return None  # reject the request

            # Only build the event (it opens a Redis connection and a
            # subscription) once the request is known to be accepted, so
            # rejected requests no longer leak one.
            event = DataEvent()
            self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
            self._index += 1
            # Increment the count for this IP
            self.increment_ip_count(client_ip, 'queued_ip_count')
        return event

    def get(self):
        """Block until an entry is available and return it.

        Polls the sorted set once per second. The returned value is the
        JSON-decoded entry ``[index, item, event_id]``; the queued counter of
        the entry's client IP is decremented before returning.
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                item = json.loads(data[0][0])
                client_ip = item[1][1]
                # Decrement the count for this IP
                with self._lock:
                    self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(1)  # wait for an item to be added to the queue

    def increment_ip_count(self, ip, key):
        """Add 1 to field ``ip`` of the Redis hash ``key``."""
        self.redis.hincrby(key, ip, 1)

    def decrement_ip_count(self, ip, key):
        """Subtract 1 from field ``ip`` of the Redis hash ``key``."""
        self.redis.hincrby(key, ip, -1)

    def __len__(self):
        """Return the number of entries currently in the ``queue`` sorted set."""
        return self.redis.zcard('queue')
class DataEvent:
    """One-shot result channel carried over Redis pub/sub (DB 14, localhost).

    The channel name is ``event_id``. One side calls ``wait()`` on an
    instance; another side publishes the result with ``set()`` on an instance
    that has the same ``event_id``.
    """

    def __init__(self, event_id=None):
        """Open a connection and subscribe to the channel named ``event_id``.

        :param event_id: channel name to use; a random UUID4 string is
            generated when it is omitted or falsy.
        """
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on the channel and return its payload.

        The subscription is closed once this method returns, because the
        event carries a single result; previously it was never released.

        :return: the unpickled payload, or ``None`` if the listener ends
            without delivering a message.
        """
        try:
            for item in self.pubsub.listen():
                if item['type'] == 'message':
                    # SECURITY: pickle.loads can execute arbitrary code. This
                    # is only safe while everything able to publish to this
                    # Redis instance is trusted.
                    return pickle.loads(item['data'])
        finally:
            # Release the subscription and its connection.
            self.pubsub.close()
# Module-level queue instance consumed by worker(). Constructing it at import
# time connects to the local Redis server (DB 15) and deletes every key in
# that DB (see RedisPriorityQueue.__init__).
priority_queue = RedisPriorityQueue()
def worker():
    """Serve queued generation requests forever.

    For each entry taken from ``priority_queue`` this:

    1. marks the client IP as processing and bumps ``active_gen_workers``;
    2. runs ``generator`` on the request body;
    3. records ``(end_time, elapsed_time)`` in ``generation_elapsed``;
    4. publishes ``(success, response, error_msg)`` on the entry's event;
    5. restores both counters.

    If ``generator`` raises, the traceback is printed, a failure tuple is
    published so the waiting side is not left blocked, and the worker keeps
    running. The counters are restored in a ``finally`` so they cannot stay
    incremented after an error.
    """
    import traceback

    while True:
        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()

        increment_ip_count(client_ip, 'processing_ips')
        redis.incr('active_gen_workers')
        try:
            start_time = time.time()
            try:
                success, response, error_msg = generator(request_json_body)
            except Exception as exc:
                traceback.print_exc()
                # NOTE(review): confirm that consumers of the event may
                # surface this text to clients.
                success, response, error_msg = False, None, '{}: {}'.format(type(exc).__name__, exc)
            else:
                end_time = time.time()
                elapsed_time = end_time - start_time
                with generation_elapsed_lock:
                    generation_elapsed.append((end_time, elapsed_time))

            # A fresh DataEvent with the same id publishes to the channel the
            # submitter is subscribed to.
            event = DataEvent(event_id)
            event.set((success, response, error_msg))
        finally:
            decrement_ip_count(client_ip, 'processing_ips')
            redis.decr('active_gen_workers')
def start_workers(num_workers: int):
    """Start ``num_workers`` threads, each running ``worker``.

    :param num_workers: how many worker threads to launch.
    """
    for _worker_number in range(num_workers):
        thread = threading.Thread(target=worker)
        thread.start()