# 130 lines · 4.0 KiB · Python
import json
|
|
import pickle
|
|
import threading
|
|
import time
|
|
from uuid import uuid4
|
|
|
|
from redis import Redis
|
|
|
|
from llm_server import opts
|
|
from llm_server.llm.generator import generator
|
|
from llm_server.routes.cache import redis
|
|
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
|
|
|
|
# Runs at import time: overwrite the 'processing_ips' dict with an empty one so
# the per-IP counts maintained by increment_ip_count/decrement_ip_count below
# start from zero.
redis.set_dict('processing_ips', {})
|
|
|
|
|
|
def increment_ip_count(client_ip: int, redis_key):
    """Add one to the counter kept for ``client_ip`` and return the updated dict.

    The dict stored under ``redis_key`` is fetched with ``redis.get_dict``,
    the entry for ``client_ip`` is bumped (a missing entry counts as 0), and
    the whole dict is written back with ``redis.set_dict``.

    NOTE(review): this is a read-modify-write of the full dict; two callers
    running at once can overwrite each other's update -- confirm whether the
    callers need this to be atomic.
    """
    counts = redis.get_dict(redis_key)
    previous = counts.get(client_ip, 0)
    counts[client_ip] = previous + 1
    redis.set_dict(redis_key, counts)
    return counts
|
|
|
|
|
|
def decrement_ip_count(client_ip: int, redis_key):
    """Subtract one from the counter kept for ``client_ip`` and return the dict.

    An IP that has no entry is left alone.  An entry that reaches exactly 0 is
    removed so the dict only lists IPs with outstanding work.  The dict is
    written back with ``redis.set_dict`` in every case.

    NOTE(review): like the increment, this is a non-atomic read-modify-write
    of the full dict.
    """
    counts = redis.get_dict(redis_key)
    if client_ip in counts:
        remaining = counts[client_ip] - 1
        if remaining == 0:
            # Drop the IP entirely once nothing is left for it.
            del counts[client_ip]
        else:
            counts[client_ip] = remaining
    redis.set_dict(redis_key, counts)
    return counts
|
|
|
|
|
|
class RedisPriorityQueue:
    """Priority queue of generation requests backed by a Redis sorted set.

    Entries live in the ``queue`` sorted set of Redis DB 15 as JSON-encoded
    ``(index, item, event_id)`` triples scored by ``-priority``; ``get`` pops
    the lowest score, i.e. the largest ``priority`` value, first.  A per-IP
    count of queued entries is kept in the ``queued_ip_count`` hash and is
    used by ``put`` to reject clients that already have too many waiting.

    NOTE(review): entries with equal score are ordered by Redis by member
    string, not by insertion order, so the leading ``index`` does not give a
    strict FIFO tie-break once it has more than one digit -- confirm whether
    that matters to callers.
    """

    def __init__(self):
        self._index = 0
        # Process-local lock: it serialises threads of this process only.
        self._lock = threading.Lock()
        self.redis = Redis(host='localhost', port=6379, db=15)

        # Start from an empty DB so queue entries and IP counts left over from
        # a previous run are not picked up.  FLUSHDB clears the selected DB in
        # one round trip instead of one DELETE per key.
        self.redis.flushdb()

        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe('events')

    def put(self, item, priority):
        """Enqueue ``item`` and return the event its result will arrive on.

        Args:
            item: JSON-serialisable sequence whose element ``[1]`` is the
                client IP used for the per-IP limit.
            priority: Larger values are dequeued first.  A priority of 0 is
                exempt from the per-IP limit.

        Returns:
            The ``DataEvent`` created for this entry, or ``None`` when the
            client already has ``opts.simultaneous_requests_per_ip`` or more
            entries queued and the request is rejected.
        """
        client_ip = item[1]
        # The limit check, the enqueue, the index bump and the count increment
        # are done under one lock so concurrent put() calls in this process
        # cannot both pass the check or reuse the same index.
        with self._lock:
            ip_count = self.redis.hget('queued_ip_count', client_ip)
            if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
                return None  # reject the request

            # Create the event only for accepted requests (it opens a Redis
            # connection and a subscription), and before the entry becomes
            # visible to workers so the subscription is requested first.
            event = DataEvent()
            member = json.dumps((self._index, item, event.event_id))
            self.redis.zadd('queue', {member: -priority})
            self._index += 1
            self.increment_ip_count(client_ip, 'queued_ip_count')
        return event

    def get(self):
        """Block until an entry is available and return it.

        Returns:
            The decoded ``[index, item, event_id]`` list (JSON decoding turns
            the stored tuples into lists).
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                # zpopmin yields (member, score) pairs; only the member is used.
                item = json.loads(data[0][0])
                client_ip = item[1][1]
                # The entry has left the queue, so release its slot for this IP.
                with self._lock:
                    self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(1)  # queue is empty; poll again in a second

    def increment_ip_count(self, ip, key):
        """Add one to field ``ip`` of the Redis hash ``key``."""
        self.redis.hincrby(key, ip, 1)

    def decrement_ip_count(self, ip, key):
        """Subtract one from field ``ip`` of the Redis hash ``key``."""
        self.redis.hincrby(key, ip, -1)

    def __len__(self):
        """Return the number of entries currently in the ``queue`` sorted set."""
        return self.redis.zcard('queue')
|
|
|
|
|
|
class DataEvent:
    """One-shot result channel carried over Redis pub/sub (DB 14).

    Each instance opens its own Redis connection and subscribes to a channel
    named after ``event_id``; ``set`` publishes a pickled payload on that
    channel and ``wait`` blocks until a payload arrives.
    """

    def __init__(self, event_id=None):
        # A falsy event_id means "make a new channel"; otherwise attach to an
        # existing one.
        self.event_id = event_id or str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Pickle ``data`` and publish it on this event's channel."""
        payload = pickle.dumps(data)
        self.redis.publish(self.event_id, payload)

    def wait(self):
        """Block until a message arrives on the channel and return it unpickled.

        NOTE(review): ``pickle.loads`` runs on whatever is published to the
        channel; this is only safe while every publisher on this Redis DB is
        trusted.
        """
        for message in self.pubsub.listen():
            # Skip subscribe confirmations and other non-payload frames.
            if message['type'] != 'message':
                continue
            return pickle.loads(message['data'])
|
|
|
|
|
|
# Module-level queue instance used by worker() below.  Constructing it at
# import time connects to Redis and clears DB 15 (see RedisPriorityQueue.__init__).
priority_queue = RedisPriorityQueue()
|
|
|
|
|
|
def worker():
    """Run forever: take requests off ``priority_queue`` and generate replies.

    For each request the worker marks the client IP as being processed, bumps
    the ``active_gen_workers`` counter, calls ``generator`` on the request
    body, records the elapsed time, and publishes
    ``(success, response, error_msg)`` on the request's event channel.

    The counters are released in a ``finally`` block and a failure inside
    ``generator`` is reported to the waiting client instead of killing the
    thread, so one bad request can neither leave the counters stuck nor leave
    the client blocked on its event.
    """
    import traceback

    while True:
        _index, (request_json_body, client_ip, token, _parameters), event_id = priority_queue.get()

        increment_ip_count(client_ip, 'processing_ips')

        # TODO: only increment if not valid SYSTEM__ token
        redis.incr('active_gen_workers')

        try:
            try:
                start_time = time.time()
                success, response, error_msg = generator(request_json_body)
                end_time = time.time()

                elapsed_time = end_time - start_time
                with generation_elapsed_lock:
                    generation_elapsed.append((end_time, elapsed_time))
            except Exception as exc:
                # Keep the traceback visible, then report the failure.
                # NOTE(review): assumes consumers of the event treat
                # success=False with response=None as a failed generation --
                # confirm against the code that calls DataEvent.wait().
                traceback.print_exc()
                success, response, error_msg = False, None, repr(exc)

            event = DataEvent(event_id)
            event.set((success, response, error_msg))
        finally:
            # Always undo the bookkeeping done before the try, even when
            # generation or publishing failed.
            decrement_ip_count(client_ip, 'processing_ips')

            # TODO: only decrement if not valid SYSTEM__ token
            redis.decr('active_gen_workers')
|
|
|
|
|
|
def start_workers(num_workers: int):
    """Start ``num_workers`` threads, each running ``worker``.

    The threads are started immediately and are neither joined nor kept; the
    function returns once the last one has been started.
    """
    pending = (threading.Thread(target=worker) for _ in range(num_workers))
    for thread in pending:
        thread.start()
|