# local-llm-server/llm_server/routes/queue.py

import json
import pickle
import time
from typing import Tuple
from uuid import uuid4

from redis import Redis

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit


def increment_ip_count(client_ip: str, redis_key):
    redis.hincrby(redis_key, client_ip, 1)


def decrement_ip_count(client_ip: str, redis_key):
    new_count = redis.hincrby(redis_key, client_ip, -1)
    if new_count <= 0:
        # Remove the hash field once the count drops to zero so stale IPs don't linger.
        redis.hdel(redis_key, client_ip)
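
# The per-IP counters live in a plain Redis hash; an illustrative layout
# (values made up):
#   HGETALL queued_ip_count  ->  {"1.2.3.4": "2", "5.6.7.8": "1"}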


class RedisPriorityQueue:
    """A priority queue backed by a Redis sorted set, with one queue per backend."""

    def __init__(self, name, db: int = 12):
        self.name = name
        self.redis = RedisCustom(name, db=db)

    def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
        assert item is not None
        assert priority is not None
        assert selected_model is not None

        event = DataEvent()

        # Check if the IP is already in the counter hash and has reached its limit.
        ip_count = self.redis.hget('queued_ip_count', item[1])
        if ip_count:
            ip_count = int(ip_count)
        _, simultaneous_ip = get_token_ratelimit(item[2])
        if ip_count and ip_count >= simultaneous_ip and priority != 0:
            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
            return None  # reject the request

        timestamp = time.time()
        # Scores are negated so that higher priorities pop first from zpopmin().
        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
        self.increment_ip_count(item[1], 'queued_ip_count')
        return event
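
    # Each queued entry is a JSON array scored by -priority. An illustrative
    # member (values made up):
    #   [[{...params}, "1.2.3.4", "token123", {...request}], "<event uuid>", "some-model", 1697500000.0, false]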

    def get(self):
        while True:
            data = self.redis.zpopmin('queue')
            if data:
                item = json.loads(data[0][0])
                client_ip = item[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                return item
            time.sleep(0.1)  # wait for something to be added to the queue

    # def print_all_items(self):
    #     items = self.redis.zrange('queue', 0, -1)
    #     to_print = []
    #     for item in items:
    #         to_print.append(item.decode('utf-8'))
    #     print(f'ITEMS {self.name} -->', to_print)

    def increment_ip_count(self, client_ip: str, redis_key):
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        new_count = self.redis.hincrby(redis_key, client_ip, -1)
        if new_count <= 0:
            self.redis.hdel(redis_key, client_ip)

    def __len__(self):
        return self.redis.zcard('queue')

    def get_queued_ip_count(self, client_ip: str):
        q = self.redis.hget('queued_ip_count', client_ip)
        if not q:
            return 0
        return int(q)

    def flush(self):
        self.redis.flush()

    def items(self):
        return self.redis.zrange('queue', 0, -1)

    def cleanup(self):
        now = time.time()
        for item in self.items():
            item_data = json.loads(item)
            timestamp = item_data[-2]
            if now - timestamp > opts.backend_generate_request_timeout:
                self.redis.zrem('queue', item)
                event_id = item_data[1]
                client_ip = item_data[0][1]
                self.decrement_ip_count(client_ip, 'queued_ip_count')
                # Wake the waiting client so it doesn't block forever on a dead request.
                event = DataEvent(event_id)
                event.set((False, None, 'closed'))
                print('Removed timed-out item from queue:', event_id)
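
# cleanup() is intended to be invoked periodically; a minimal sketch of a
# background sweeper (the loop below is illustrative, not part of this module):
#   while True:
#       for backend_url in priority_queue.get_backends():
#           RedisPriorityQueue(backend_url).cleanup()
#       time.sleep(30)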


class DataEvent:
    """A cross-process event implemented on Redis pub/sub, keyed by a UUID channel."""

    def __init__(self, event_id=None):
        self.event_id = event_id if event_id else str(uuid4())
        self.redis = Redis(host='localhost', port=6379, db=14)
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(self.event_id)

    def set(self, data):
        self.redis.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        # TODO: implement timeout
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                return pickle.loads(item['data'])
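
# A sketch of the intended request/worker handshake (names illustrative): the
# web request enqueues and blocks on the event, while a worker pops the item
# and publishes the result back over the event's pub/sub channel.
#
#   # request side
#   event = priority_queue.put(backend_url, (params, client_ip, token, request), 1, 'some-model')
#   success, response, error = event.wait()  # blocks until a worker (or cleanup) calls set()
#
#   # worker side
#   item, event_id, model, ts, do_stream = RedisPriorityQueue(backend_url).get()
#   DataEvent(event_id).set((True, result, None))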


def update_active_workers(key: str, operation: str):
    if operation == 'incr':
        redis.incr(f'active_gen_workers:{key}')
    elif operation == 'decr':
        redis.decr(f'active_gen_workers:{key}')
        # Clamp the counter at zero in case of unbalanced decrements.
        if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
            redis.set(f'active_gen_workers:{key}', 0)


def incr_active_workers(selected_model: str, backend_url: str):
    update_active_workers(selected_model, 'incr')
    update_active_workers(backend_url, 'incr')


def decr_active_workers(selected_model: str, backend_url: str):
    update_active_workers(selected_model, 'decr')
    update_active_workers(backend_url, 'decr')
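
# Workers are expected to bracket each generation with these helpers; a
# minimal sketch of the intended usage (the generation call is illustrative):
#   incr_active_workers(model_name, backend_url)
#   try:
#       result = run_generation(...)
#   finally:
#       decr_active_workers(model_name, backend_url)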


class PriorityQueue:
    def __init__(self, backends: set = None):
        """
        The set of backends only needs to be loaded once; subsequent
        instances read it back from Redis.
        :param backends:
        """
        self.redis = Redis(host='localhost', port=6379, db=9)
        if backends:
            for item in backends:
                self.redis.sadd('backends', item)

    def get_backends(self):
        return {x.decode('utf-8') for x in self.redis.smembers('backends')}

    def get_queued_ip_count(self, client_ip: str):
        count = 0
        for backend_url in self.get_backends():
            queue = RedisPriorityQueue(backend_url)
            count += queue.get_queued_ip_count(client_ip)
        return count

    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
        queue = RedisPriorityQueue(backend_url)
        return queue.put(item, priority, selected_model, do_stream)

    def activity(self):
        lines = []
        status_redis = RedisCustom('worker_status')
        for worker in status_redis.keys():
            lines.append((worker, status_redis.getp(worker)))
        return sorted(lines)

    def len(self, model_name):
        # Sum the queue lengths of every backend serving the given model.
        count = 0
        backends_with_models = set()
        for k in self.get_backends():
            info = cluster_config.get_backend(k)
            if info.get('model') == model_name:
                backends_with_models.add(k)
        for backend_url in backends_with_models:
            count += len(RedisPriorityQueue(backend_url))
        return count

    def __len__(self):
        count = 0
        for backend_url in self.get_backends():
            count += len(RedisPriorityQueue(backend_url))
        return count

    def flush(self):
        # Flush each backend's queue rather than the registry itself.
        for backend_url in self.get_backends():
            RedisPriorityQueue(backend_url).flush()

    def flush_db(self):
        self.redis.flushdb()


priority_queue = PriorityQueue()
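

# A minimal smoke test, assuming a Redis server on localhost:6379 (run this
# module directly; the round-tripped tuple below is illustrative):
if __name__ == '__main__':
    ev = DataEvent()                                    # subscriber side
    DataEvent(ev.event_id).set((True, 'hello', None))   # publisher side
    print(ev.wait())                                    # -> (True, 'hello', None)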