59 lines
1.8 KiB
Python
59 lines
1.8 KiB
Python
|
import json
|
||
|
import threading
|
||
|
import time
|
||
|
|
||
|
from llm_server import opts
|
||
|
from llm_server.llm.generator import generator
|
||
|
from llm_server.routes.cache import redis
|
||
|
from llm_server.routes.queue import DataEvent, decrement_ip_count, increment_ip_count, priority_queue
|
||
|
|
||
|
|
||
|
def worker():
    """Serve generation requests from ``priority_queue`` forever.

    Each iteration:

    1. Calls ``need_to_wait()`` before and after taking an item off
       ``priority_queue``, so the worker blocks while the redis counter
       ``active_gen_workers`` is at ``opts.concurrent_gens``.
    2. Increments the client's ``processing_ips`` count and ``active_gen_workers``.
    3. If the request body is empty (a dummy request from the websocket handler),
       skips generation and leaves both counters for that handler to decrement.
    4. Otherwise calls ``generator(request_json_body)``, publishes
       ``(success, response, error_msg)`` on ``DataEvent(event_id)``, and always
       decrements both counters afterwards.

    If ``generator`` raises, the traceback is printed and a failure tuple is
    published instead. Previously the exception escaped the loop; that killed
    this thread and left the waiting caller's event unset.
    """
    import traceback  # local import: only needed on the error paths below

    while True:
        need_to_wait()
        (request_json_body, client_ip, _token, _parameters), event_id = priority_queue.get()
        need_to_wait()

        increment_ip_count(client_ip, 'processing_ips')
        redis.incr('active_gen_workers')

        if not request_json_body:
            # This was a dummy request from the websocket handler.
            # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
            continue

        try:
            try:
                success, response, error_msg = generator(request_json_body)
            except Exception as e:
                traceback.print_exc()
                # NOTE(review): assumes consumers ignore `response` when success is False;
                # confirm against the code that waits on DataEvent.
                success, response, error_msg = False, None, f'{e.__class__.__name__}: {e}'
            DataEvent(event_id).set((success, response, error_msg))
        except Exception:
            # Publishing the result failed; log it but keep this worker thread alive.
            traceback.print_exc()
        finally:
            decrement_ip_count(client_ip, 'processing_ips')
            redis.decr('active_gen_workers')
|
||
|
|
||
|
|
||
|
def start_workers(num_workers: int):
    """Start ``num_workers`` daemon threads, each running ``worker``.

    Daemon threads do not block interpreter shutdown. A non-positive
    ``num_workers`` starts no threads.

    :param num_workers: number of worker threads to spawn.
    :return: list of the started ``threading.Thread`` objects.
    """
    threads = []
    for index in range(num_workers):
        thread = threading.Thread(target=worker, name=f'inference-worker-{index}', daemon=True)
        thread.start()
        threads.append(thread)
    # Report the number actually started, rather than keeping a separate counter.
    print(f'Started {len(threads)} inference workers.')
    return threads
|
||
|
|
||
|
|
||
|
def need_to_wait():
    """Block until the redis counter ``active_gen_workers`` drops below ``opts.concurrent_gens``.

    The shared redis counter is checked, rather than local state, because the
    streaming endpoint may also be running generations. Polls every 10 ms and
    prints a notice if the total wait exceeded 0.5 seconds.
    """
    start = time.time()
    # Re-read the counter on every poll. Reading it once before the loop
    # meant a call made while at capacity could never observe a freed slot
    # and spun forever.
    # NOTE(review): this check is not atomic with any later increment of the
    # counter, so concurrent callers may briefly overshoot the limit.
    while redis.get('active_gen_workers', int, 0) >= opts.concurrent_gens:
        time.sleep(0.01)
    elapsed = time.time() - start
    if elapsed > 0.5:
        print(f'Worker was delayed {elapsed} seconds.')
|