local-llm-server/llm_server/workers/blocking.py

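"""Blocking inference workers.

Each worker loops forever: it pulls a request off the priority queue, waits for a
free generation slot, runs the request through the generator, and publishes the
result on a DataEvent keyed by the queue entry's event_id.
"""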
import json
import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.queue import DataEvent, decrement_ip_count, increment_ip_count, priority_queue


def worker():
    while True:
        need_to_wait()
        (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
        need_to_wait()

        increment_ip_count(client_ip, 'processing_ips')
        redis.incr('active_gen_workers')

        if not request_json_body:
            # This was a dummy request from the websocket handler.
            # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
            continue

        try:
            start_time = time.time()
            success, response, error_msg = generator(request_json_body)
            end_time = time.time()
            elapsed_time = end_time - start_time
            # redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
            event = DataEvent(event_id)
            event.set((success, response, error_msg))
        finally:
            decrement_ip_count(client_ip, 'processing_ips')
            redis.decr('active_gen_workers')


def start_workers(num_workers: int):
    i = 0
    for _ in range(num_workers):
        t = threading.Thread(target=worker)
        t.daemon = True
        t.start()
        i += 1
    print(f'Started {i} inference workers.')
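
# Example wiring (a sketch; the real call site lives elsewhere in the app): the
# server entry point presumably starts one blocking worker per allowed concurrent
# generation, e.g. start_workers(opts.concurrent_gens). The threads are daemons,
# so they exit with the main process.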


def need_to_wait():
    # Check the number of active workers, since the streaming endpoint may also be holding generation slots.
    active_workers = redis.get('active_gen_workers', int, 0)
    s = time.time()
    while active_workers >= opts.concurrent_gens:
        time.sleep(0.01)
        # Re-read the counter so the loop can exit once a generation slot frees up.
        active_workers = redis.get('active_gen_workers', int, 0)
    e = time.time()
    if e - s > 0.5:
        print(f'Worker was delayed {e - s} seconds.')