# Source listing: 73 lines, 2.9 KiB, Python.
import threading
import time
import traceback

from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue


def worker():
    """Serve requests from ``priority_queue`` forever, one at a time.

    Each queue item is unpacked as
    ``((request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model)``
    and handed to :func:`_handle_request`.

    Any exception raised while handling a single request is logged. For
    non-streaming requests it is also reported to the waiting handler as a
    failed ``(success, response, error_msg)`` tuple. The loop then continues.
    Previously such an exception escaped the loop and ended the thread, which
    permanently shrank the worker pool.
    """
    while True:
        (request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
        try:
            _handle_request(request_json_body, client_ip, backend_url, event_id, selected_model)
        except Exception as e:
            # This thread is a top-level boundary: log and keep serving.
            traceback.print_exc()
            if request_json_body:
                # A non-streaming handler is blocked on DataEvent(event_id);
                # without this it would never receive a result.
                _report_failure(event_id, e)


def _handle_request(request_json_body, client_ip, backend_url, event_id, selected_model):
    """Resolve the backend for one dequeued request and run it while holding accounting slots.

    The backend is picked via ``get_a_cluster_backend(selected_model)`` when
    the request did not pin one. Otherwise the pinned URL is passed through
    ``cluster_config.validate_backend``. If no model was selected, the
    backend's ``'model'`` is used. The client IP's ``processing_ips`` count
    and the active-worker counters are incremented for the duration of the
    request and always decremented afterwards.
    """
    if not backend_url:
        backend_url = get_a_cluster_backend(selected_model)
    else:
        backend_url = cluster_config.validate_backend(backend_url)
    backend_info = cluster_config.get_backend(backend_url)

    if not selected_model:
        selected_model = backend_info['model']

    increment_ip_count(client_ip, 'processing_ips')
    incr_active_workers(selected_model, backend_url)
    try:
        # Inside the ``try`` so the two increments above are undone even if
        # the wait raises. Previously it sat outside, and a failure here left
        # both counters incremented forever.
        # NOTE(review): this still runs after incr_active_workers(). If that
        # helper bumps the same ``active_gen_workers:<backend_url>`` key that
        # need_to_wait() reads, this worker counts itself; confirm the ordering.
        need_to_wait(backend_url)

        if not request_json_body:
            # This was a dummy request from the streaming handlers.
            _wait_for_stream_handler(event_id)
        else:
            # Normal inference (not streaming).
            success, response, error_msg = generator(request_json_body, backend_url)
            event = DataEvent(event_id)
            event.set((success, response, error_msg))
    finally:
        decrement_ip_count(client_ip, 'processing_ips')
        decr_active_workers(selected_model, backend_url)


def _wait_for_stream_handler(event_id):
    """Tell a streaming handler to start, then block until it publishes ``'finished'``.

    The worker lets the handler do the streaming instead of doing it itself,
    and blocks until the handler is finished. Since much of the rate limiting
    and stats are based on the number of active workers, the generation must
    stay tied to a worker.
    """
    pubsub = redis.pubsub()
    try:
        pubsub.subscribe(event_id)
        redis.publish(event_id, 'begin')
        for item in pubsub.listen():
            if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
                # Once the handler is complete, move on.
                break
            time.sleep(0.1)
    finally:
        # The original never unsubscribed or closed this object, leaking one
        # subscription (and, with redis-py, its dedicated connection) per
        # streaming request.
        pubsub.close()


def _report_failure(event_id, exc):
    """Deliver a failed ``(success, response, error_msg)`` result for ``event_id``.

    The tuple shape matches what the normal path sets from ``generator``.
    NOTE(review): this assumes handlers treat ``success=False`` with
    ``response=None`` as an error. Confirm against the DataEvent consumers.
    Errors while reporting are logged, not raised, so the worker keeps running.
    """
    try:
        DataEvent(event_id).set((False, None, f'{exc.__class__.__name__}: {exc}'))
    except Exception:
        traceback.print_exc()


def start_workers(num_workers: int):
    """Launch ``num_workers`` daemon threads, each running :func:`worker`.

    The threads are marked as daemons so they do not keep the process alive at
    shutdown. A one-line summary of how many were started is printed.
    """
    pool = [threading.Thread(target=worker, daemon=True) for _ in range(num_workers)]
    for thread in pool:
        thread.start()
    print(f'Started {len(pool)} inference workers.')


def need_to_wait(backend_url: str):
    """Block until ``backend_url`` has fewer active workers than its ``concurrent_gens`` limit.

    The limit comes from the backend's cluster config and defaults to 1. The
    active count is the Redis key ``active_gen_workers:<backend_url>``
    (default 0). The key is polled every 10 ms. If the total wait exceeded
    0.1 s, a notice with the delay is printed.
    """
    # We need to check the number of active workers since the streaming endpoint may be doing something.
    concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
    counter_key = f'active_gen_workers:{backend_url}'
    start = time.time()
    # Re-read the counter on every iteration. The original read it once before
    # the loop, so once the condition was true it could never become false and
    # the worker spun forever.
    while redis.get(counter_key, 0, dtype=int) >= concurrent_gens:
        time.sleep(0.01)
    elapsed = time.time() - start
    if elapsed > 0.1:
        print(f'Worker was delayed {elapsed} seconds.')