This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.

130 lines
4.0 KiB

import json
import pickle
import threading
import time
from uuid import uuid4
from redis import Redis
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
# Reset the per-IP "currently generating" counters every time this module is imported.
redis.set_dict('processing_ips', {})

# increment_ip_count/decrement_ip_count read the whole dict, modify it and write it
# back. That is not atomic, so concurrent worker threads could otherwise overwrite
# each other's updates. Serialise them within this process.
_ip_count_lock = threading.Lock()


def increment_ip_count(client_ip: str, redis_key):
    """Add one to ``client_ip``'s counter in the dict stored under ``redis_key``.

    :param client_ip: key identifying the client (presumably its IP address string;
        previously annotated ``int`` — confirm with callers).
    :param redis_key: name of the dict stored via ``redis.get_dict``/``redis.set_dict``.
    :return: the updated dict.
    """
    with _ip_count_lock:
        ip_count = redis.get_dict(redis_key)
        ip_count[client_ip] = ip_count.get(client_ip, 0) + 1
        redis.set_dict(redis_key, ip_count)
    return ip_count


def decrement_ip_count(client_ip: str, redis_key):
    """Subtract one from ``client_ip``'s counter in the dict stored under ``redis_key``.

    The entry is removed once it drops to zero (or below), so idle clients don't
    accumulate in the dict. Unknown clients are left untouched.

    :param client_ip: key identifying the client (presumably its IP address string).
    :param redis_key: name of the dict stored via ``redis.get_dict``/``redis.set_dict``.
    :return: the (possibly) updated dict.
    """
    with _ip_count_lock:
        ip_count = redis.get_dict(redis_key)
        if client_ip in ip_count:
            ip_count[client_ip] -= 1
            if ip_count[client_ip] <= 0:
                del ip_count[client_ip]  # Remove the IP from the dictionary if count is 0
            # Only write back when something actually changed.
            redis.set_dict(redis_key, ip_count)
    return ip_count
class RedisPriorityQueue:
    """Priority queue of pending requests backed by a Redis sorted set.

    Entries live in the ``queue`` sorted set of Redis db 15 with score ``-priority``,
    so ``zpopmin`` returns the entry with the largest ``priority`` first. The number
    of queued entries per client is tracked in the ``queued_ip_count`` hash and
    checked against ``opts.simultaneous_requests_per_ip``.
    """

    def __init__(self):
        # Monotonic counter stored in each entry (per process, reset on restart).
        self._index = 0
        # Makes the limit check, the enqueue and the counter update in put() one unit.
        self._lock = threading.Lock()
        self.redis = Redis(host='localhost', port=6379, db=15)

        # Clear the DB so no stale entries or counters survive a restart.
        for key in self.redis.scan_iter('*'):
            self.redis.delete(key)

        self.pubsub = self.redis.pubsub()

    def put(self, item, priority):
        """Enqueue ``item`` with ``priority`` and return a DataEvent for its result.

        ``item[1]`` is used as the per-client key for the queued-request limit.
        Returns None (request rejected) when that client already has
        ``opts.simultaneous_requests_per_ip`` entries queued, unless ``priority == 0``,
        which is exempt from the limit.
        """
        client_ip = item[1]
        with self._lock:
            # Check and increment under one lock so concurrent put() calls in this
            # process cannot both pass the check and exceed the limit.
            ip_count = self.redis.hget('queued_ip_count', client_ip)
            if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
                return None  # reject the request

            # Only allocate an event once the request has been accepted.
            event = DataEvent()
            # NOTE(review): members with equal scores are popped in lexicographic order
            # of their JSON text, so within one priority FIFO breaks once _index reaches
            # 10 ("[10, ..." sorts before "[9, ...").
            self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
            self._index += 1
            self.increment_ip_count(client_ip, 'queued_ip_count')
        return event

    def get(self):
        """Block until an entry is available, then pop and return the highest-priority one.

        Returns the JSON-decoded ``[index, item, event_id]`` (tuples come back as
        lists). Polls Redis once per second while the queue is empty.
        """
        while True:
            popped = self.redis.zpopmin('queue')
            if popped:
                member, _score = popped[0]
                entry = json.loads(member)
                client_ip = entry[1][1]
                # HINCRBY is atomic on the Redis side, so no local lock is needed.
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return entry
            time.sleep(1)  # wait for an item to be added to the queue

    def increment_ip_count(self, ip, key):
        """Atomically add one to field ``ip`` of Redis hash ``key``."""
        self.redis.hincrby(key, ip, 1)

    def decrement_ip_count(self, ip, key):
        """Atomically subtract one from field ``ip`` of Redis hash ``key``."""
        self.redis.hincrby(key, ip, -1)

    def __len__(self):
        """Number of entries currently waiting in the queue."""
        return self.redis.zcard('queue')
class DataEvent:
    """One-shot result channel between a queued request and the worker serving it.

    Uses Redis pub/sub on db 14, with the event id as the channel name.
    """

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        # Subscribe right away. Redis pub/sub does not buffer, so a result published
        # before the subscription exists is lost. Without any subscription, listen()
        # yields nothing and wait() would return None immediately.
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) on this event's channel."""
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        """Block until a message arrives on this event's channel and return its payload."""
        try:
            for item in self.pubsub.listen():
                if item['type'] == 'message':
                    # SECURITY: pickle.loads can execute arbitrary code. This is only safe
                    # while nothing but trusted workers can publish to this Redis instance.
                    return pickle.loads(item['data'])
        finally:
            # The event is one-shot; release its dedicated pub/sub connection.
            self.pubsub.close()
# Module-wide queue shared by request handlers and worker threads.
# Constructing it connects to Redis db 15 and clears that DB.
priority_queue = RedisPriorityQueue()


def worker():
    """Serve queued requests forever.

    For each entry popped from ``priority_queue``, mark the client as processing,
    run ``generator`` on the request body, and record the elapsed time in
    ``generation_elapsed``. Then publish ``(success, response, error_msg)`` on the
    request's DataEvent.
    """
    import traceback

    while True:
        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()

        increment_ip_count(client_ip, 'processing_ips')
        # TODO: only increment if not valid SYSTEM__ token
        try:
            start_time = time.time()
            try:
                success, response, error_msg = generator(request_json_body)
            except Exception as e:
                # Without this, the thread dies and the client waiting on the event
                # blocks forever. NOTE(review): confirm callers handle response=None
                # when success is False.
                traceback.print_exc()
                success, response, error_msg = False, None, f'{type(e).__name__}: {e}'
            else:
                end_time = time.time()
                elapsed_time = end_time - start_time
                with generation_elapsed_lock:
                    generation_elapsed.append((end_time, elapsed_time))

            event = DataEvent(event_id)
            event.set((success, response, error_msg))
        finally:
            # Always release the processing slot, even if publishing failed.
            decrement_ip_count(client_ip, 'processing_ips')
            # TODO: only decrement if not valid SYSTEM__ token
def start_workers(num_workers: int):
for _ in range(num_workers):