# Python source: 199 lines, 6.6 KiB
import pickle
|
|
import time
|
|
from typing import Tuple
|
|
from uuid import uuid4
|
|
|
|
import ujson as json
|
|
from redis import Redis
|
|
|
|
from llm_server import opts
|
|
from llm_server.cluster.cluster_config import cluster_config
|
|
from llm_server.custom_redis import RedisCustom, redis
|
|
from llm_server.database.database import get_token_ratelimit
|
|
|
|
|
|
def increment_ip_count(client_ip: str, redis_key):
    """
    Add one to the counter for ``client_ip`` inside the Redis hash ``redis_key``.

    :param client_ip: hash field to bump (the client's IP address).
    :param redis_key: name of the Redis hash that holds the per-IP counters.
    """
    step = 1
    redis.hincrby(redis_key, client_ip, step)
|
|
|
|
|
|
def decrement_ip_count(client_ip: str, redis_key):
    """
    Subtract one from the counter for ``client_ip`` in the Redis hash ``redis_key``.

    Once the counter drops to zero or below, the field is deleted so the hash
    does not keep entries for IPs with no outstanding requests.

    :param client_ip: hash field to decrement (the client's IP address).
    :param redis_key: name of the Redis hash that holds the per-IP counters.
    """
    remaining = redis.hincrby(redis_key, client_ip, -1)
    if remaining > 0:
        return
    redis.hdel(redis_key, client_ip)
|
|
|
|
|
|
class RedisPriorityQueue:
    """
    Priority queue stored in a Redis sorted set under the key ``queue``.

    Each member is a JSON-encoded tuple
    ``(item, event_id, selected_model, timestamp, do_stream)`` scored by
    ``-priority``, so ``zpopmin`` returns the highest-priority entry first.
    The Redis client is a ``RedisCustom`` built from ``name`` and ``db``
    (presumably ``name`` namespaces the keys -- confirm in RedisCustom).
    """

    def __init__(self, name, db: int = 12):
        self.name = name
        self.redis = RedisCustom(name, db=db)

    def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
        """
        Enqueue ``item`` and return a ``DataEvent`` the caller can wait on.

        :param item: request tuple; ``item[1]`` is the client IP and ``item[2]`` is
            passed to ``get_token_ratelimit`` (presumably the API token -- confirm).
        :param priority: larger values are dequeued first; priority ``0`` is never
            rejected by the simultaneous-request limit.
        :param selected_model: model name stored alongside the item.
        :param do_stream: flag stored alongside the item.
        :return: the ``DataEvent`` for this request, or ``None`` if the client IP
            already has ``simultaneous_ip`` or more requests queued.
        """
        # TODO: remove this when we're sure nothing strange is happening
        assert item is not None
        assert priority is not None
        assert selected_model is not None

        # Check if the IP is already in the queue and if it has reached the limit.
        ip_count = self.get_ip_request_count(item[1])
        _, simultaneous_ip = get_token_ratelimit(item[2])
        if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
            print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
            return None  # reject the request

        # Create the event only after the request is accepted: DataEvent opens a
        # Redis client and subscribes to a channel, which a rejected request would
        # otherwise leave behind.
        event = DataEvent()
        timestamp = time.time()
        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
        return event

    def get(self):
        """
        Block until an entry is available, then pop and return the highest-priority one.

        Polls Redis every 0.1 s. The result is the JSON-decoded
        ``(item, event_id, selected_model, timestamp, do_stream)`` sequence
        (a list, since JSON has no tuples).
        """
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                # Decode the member of the first popped (member, score) pair.
                return json.loads(data[0][0])
            time.sleep(0.1)  # wait for something to be added to the queue

    def __len__(self):
        """Return the number of entries currently in the queue."""
        return self.redis.zcard('queue')

    def get_ip_request_count(self, client_ip: str):
        """
        Get the number of requests in the queue from a specific IP.

        This is a bit inefficient since we iterate over the entire queue, but
        keeps the queue as a single point of truth instead of tracking a separate
        hashed set which can get confusing.
        If we run into slowdowns in the future, we should go back to the hashed set approach.

        :param client_ip: IP address to count (compared against ``item[1]``).
        :return: number of queued entries submitted by ``client_ip``.
        :raises Exception: if the scan took longer than 0.5 s (deliberate alarm
            signalling that the hashed-set approach is needed again).
        """
        start_time = time.time()
        items = self.redis.zrange('queue', 0, -1)
        # Element 0 of each stored tuple is the request item; its index 1 is the client IP.
        count = sum(1 for item in items if json.loads(item)[0][1] == client_ip)
        elapsed_time = time.time() - start_time
        if elapsed_time > 0.5:
            raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
        return count

    def flush(self):
        """Clear this queue's data via ``RedisCustom.flush``."""
        self.redis.flush()

    def items(self):
        """Return every raw (JSON-encoded) member of the queue, lowest score first."""
        return self.redis.zrange('queue', 0, -1)

    def cleanup(self):
        """
        Drop entries older than ``opts.backend_generate_request_timeout`` seconds.

        For each dropped entry, ``(False, None, 'closed')`` is published on the
        entry's ``DataEvent`` channel so whoever is waiting on it is released.
        """
        now = time.time()
        for item in self.items():
            item_data = json.loads(item)
            timestamp = item_data[-2]  # (item, event_id, selected_model, timestamp, do_stream)
            if now - timestamp > opts.backend_generate_request_timeout:
                # ZREM takes the members to remove; pass only this entry.
                self.redis.zrem('queue', item)
                event_id = item_data[1]
                event = DataEvent(event_id)
                event.set((False, None, 'closed'))
                print('Removed timed-out item from queue:', event_id)
|
|
|
|
|
|
class DataEvent:
    """
    Message channel over Redis pub/sub (localhost:6379, db 14).

    Construction subscribes to a channel named ``event_id`` (a fresh UUID4
    string when no truthy id is supplied). ``set`` publishes a pickled payload
    on that channel; ``wait`` blocks until a message arrives and returns the
    unpickled payload.

    NOTE(review): the pub/sub subscription and Redis client are never closed
    here -- confirm callers don't create these in unbounded numbers.
    NOTE(review): ``wait`` unpickles whatever is published on the channel; this
    is only safe while db 14 is reachable solely by trusted processes.
    """

    def __init__(self, event_id=None):
        # Any falsy id (None, '') gets replaced by a new random UUID.
        self.event_id = event_id or str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        """Publish ``data`` (pickled) on this event's channel."""
        payload = pickle.dumps(data)
        self.redis.publish(self.event_id, payload)

    def wait(self):
        """Block until a data message arrives on this channel and return its unpickled payload."""
        for message in self.pubsub.listen():
            if message['type'] != 'message':
                # Skip subscribe confirmations and other control frames.
                continue
            return pickle.loads(message['data'])
|
|
|
|
|
|
def update_active_workers(key: str, operation: str):
    """
    Adjust the Redis counter ``active_gen_workers:<key>``.

    :param key: suffix identifying the counter (a model name or backend URL).
    :param operation: ``'incr'`` to add one, ``'decr'`` to subtract one; any
        other value is silently ignored.

    After a decrement the counter is clamped back to 0 if it went negative, so
    unbalanced decrements can't leave a negative worker count.
    """
    counter_key = f'active_gen_workers:{key}'
    if operation == 'incr':
        redis.incr(counter_key)
        return
    if operation != 'decr':
        return
    redis.decr(counter_key)
    current = redis.get(counter_key, default=0, dtype=int)
    if current < 0:
        redis.set(counter_key, 0)
|
|
|
|
|
|
def incr_active_workers(selected_model: str, backend_url: str):
    """Increment the active-worker counters for the model first, then for the backend."""
    for counter in (selected_model, backend_url):
        update_active_workers(counter, 'incr')
|
|
|
|
|
|
def decr_active_workers(selected_model: str, backend_url: str):
    """Decrement the active-worker counters for the model first, then for the backend."""
    for counter in (selected_model, backend_url):
        update_active_workers(counter, 'decr')
|
|
|
|
|
|
class PriorityQueue:
    """
    Front-end over one ``RedisPriorityQueue`` per backend.

    Backend URLs are kept in the Redis set ``backends`` (localhost:6379, db 9);
    each URL is used as the name of that backend's ``RedisPriorityQueue``.
    """

    def __init__(self, backends: set = None):
        """
        Only have to load the backends once.

        :param backends: optional collection of backend URLs to add to the
            ``backends`` set; nothing is added when it is empty or ``None``.
        """
        self.redis = Redis(host='localhost', port=6379, db=9)
        if backends:
            # One variadic SADD instead of a round trip per backend.
            self.redis.sadd('backends', *backends)

    def get_backends(self):
        """Return the registered backend URLs as a set of ``str``."""
        return {x.decode('utf-8') for x in self.redis.smembers('backends')}

    def get_queued_ip_count(self, client_ip: str):
        """Return how many entries ``client_ip`` has queued across all backends."""
        return sum(
            RedisPriorityQueue(backend_url).get_ip_request_count(client_ip)
            for backend_url in self.get_backends()
        )

    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
        """
        Enqueue ``item`` on ``backend_url``'s queue.

        :return: the ``DataEvent`` from ``RedisPriorityQueue.put``, or ``None`` if rejected.
        """
        queue = RedisPriorityQueue(backend_url)
        return queue.put(item, priority, selected_model, do_stream)

    def activity(self):
        """Return sorted ``(worker, status)`` pairs read from the ``worker_status`` store."""
        status_redis = RedisCustom('worker_status')
        return sorted((worker, status_redis.getp(worker)) for worker in status_redis.keys())

    def len(self, model_name):
        """
        Return the number of queued entries across backends configured for ``model_name``.

        Backends with no entry in ``cluster_config`` are skipped rather than
        raising ``AttributeError``.
        """
        count = 0
        for backend_url in self.get_backends():
            info = cluster_config.get_backend(backend_url)
            if info is not None and info.get('model') == model_name:
                count += len(RedisPriorityQueue(backend_url))
        return count

    def __len__(self):
        """Return the total number of queued entries across all backends."""
        return sum(len(RedisPriorityQueue(backend_url)) for backend_url in self.get_backends())

    def flush(self):
        """
        Empty every backend's queue.

        The ``backends`` set itself is left intact; use ``flush_db`` to wipe
        this database entirely.
        """
        for backend_url in self.get_backends():
            RedisPriorityQueue(backend_url).flush()

    def flush_db(self):
        """Delete everything in this Redis database (db 9), including the ``backends`` set."""
        self.redis.flushdb()
|
|
|
|
|
|
# Shared module-level queue front-end. Constructed without ``backends``, so this
# adds nothing to the Redis ``backends`` set; backends registered elsewhere are
# read on each call through ``get_backends``.
priority_queue = PriorityQueue()
|