local-llm-server/llm_server/routes/queue.py

import heapq
import threading
import time

from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
processing_ips = set()
processing_ips_lock = threading.Lock()


class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._index = 0  # monotonic counter: breaks ties so equal priorities pop FIFO
        self._cv = threading.Condition()
        self._ip_count = {}  # number of queued requests per client IP

    def put(self, item, priority):
        # item is (request_json_body, client_ip, token, parameters); item[1] is the client IP
        event = DataEvent()
        with self._cv:
            # Reject the request if this IP has already reached the queue limit.
            # Priority 0 requests bypass the limit.
            if item[1] in self._ip_count and self._ip_count[item[1]] >= opts.ip_in_queue_max and priority != 0:
                return None  # reject the request
            # heapq is a min-heap, so the priority is negated: higher priority pops first
            heapq.heappush(self._queue, (-priority, self._index, item, event))
            self._index += 1
            # Increment the count for this IP
            self._ip_count[item[1]] = self._ip_count.get(item[1], 0) + 1
            self._cv.notify()
        return event
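
    # A sketch of put()'s contract (illustrative only, assuming
    # opts.ip_in_queue_max == 1; item_from_ip stands for any
    # (body, ip, token, params) tuple coming from a single client):
    #
    #     e1 = priority_queue.put(item_from_ip, 1)  # -> DataEvent, queued
    #     e2 = priority_queue.put(item_from_ip, 1)  # -> None, IP over the limit
    #     e3 = priority_queue.put(item_from_ip, 0)  # -> DataEvent, priority 0 bypasses the limit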

    def get(self):
        with self._cv:
            # Block until a request is queued
            while len(self._queue) == 0:
                self._cv.wait()
            _, _, item, event = heapq.heappop(self._queue)
            # Decrement the count for this IP
            self._ip_count[item[1]] -= 1
            if self._ip_count[item[1]] == 0:
                del self._ip_count[item[1]]  # Remove the IP from the dictionary if count is 0
            return item, event

    def __len__(self):
        return len(self._queue)
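
    # Illustration of the heap ordering (a sketch, not executed code): with
    # entries of the form (-priority, _index, item, event), the min-heap pops
    # the highest priority first and breaks ties in arrival order:
    #
    #     (-2, 0, item_a, event_a)   # priority 2: served first
    #     (-1, 1, item_b, event_b)   # priority 1: arrived before item_c
    #     (-1, 2, item_c, event_c)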


priority_queue = PriorityQueue()


class DataEvent(threading.Event):
    """An Event that also carries the worker's result back to the waiting request thread."""

    def __init__(self):
        super().__init__()
        self.data = None


def worker():
    while True:
        # priority_queue.get() blocks until a request is available
        (request_json_body, client_ip, token, parameters), event = priority_queue.get()
        redis.sadd('processing_ips', client_ip)
        redis.incr('active_gen_workers')
        start_time = time.time()

        success, response, error_msg = generator(request_json_body)
        end_time = time.time()
        elapsed_time = end_time - start_time
        # Record (timestamp, duration) for the stats endpoints
        with generation_elapsed_lock:
            generation_elapsed.append((end_time, elapsed_time))

        # Publish the result and wake the request thread waiting on this event
        event.data = (success, response, error_msg)
        event.set()
        redis.srem('processing_ips', client_ip)
        redis.decr('active_gen_workers')


def start_workers(num_workers: int):
    for _ in range(num_workers):
        threading.Thread(target=worker).start()
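

# A minimal usage sketch (hypothetical caller; the real route handlers live
# elsewhere in llm_server.routes): enqueue a request, then block on the
# returned DataEvent until a worker publishes the result into event.data.
#
#     start_workers(2)
#     event = priority_queue.put((request_json_body, client_ip, token, parameters), 1)
#     if event is None:
#         pass  # this IP has too many queued requests -> e.g. return HTTP 429
#     else:
#         event.wait()
#         success, response, error_msg = event.data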