local-llm-server/llm_server/workers/inferencer.py

73 lines
2.8 KiB
Python

import threading
import time
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
def worker():
    """Run one inference worker loop: pull queued requests and publish results.

    Each iteration takes one item from ``priority_queue``, resolves the
    backend to use, waits for a free generation slot on that backend, calls
    ``generator`` and publishes ``(success, response, error_msg)`` through the
    ``DataEvent`` built from the item's event id.

    The loop never returns.  An exception raised while handling one item is
    logged and, if no result was published yet, reported to the requester as
    a failure; the worker then moves on to the next item instead of dying.
    """
    import traceback  # only needed on the error path

    while True:
        (request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()

        # Tracks whether the requester has already been handed a result, so
        # the error path below never publishes a second, contradictory one.
        notified = False
        try:
            if not backend_url:
                backend_url = get_a_cluster_backend(selected_model)
            backend_info = cluster_config.get_backend(backend_url)

            # The backend could have died between when the request was
            # submitted and now, so let's double check it's still online.
            if not backend_info['online']:
                offline_url = backend_url
                # NOTE(review): the replacement is chosen without passing
                # selected_model, so it may serve a different model - confirm
                # this is intended.
                backend_url = get_a_cluster_backend()
                backend_info = cluster_config.get_backend(backend_url)
                print(f'Backend {offline_url} offline. Request was redirected to {backend_url}')

            if not selected_model:
                selected_model = backend_info['model']

            # This wait time will be "invisible", meaning the worker may as
            # well be still waiting to get an item from the queue.
            need_to_wait(backend_url)

            increment_ip_count(client_ip, 'processing_ips')
            incr_active_workers(selected_model, backend_url)

            if not request_json_body:
                # This was a dummy request from the websocket handlers.
                # We're going to let the websocket handler decrement
                # processing_ips and active_gen_workers.
                DataEvent(event_id).set((True, None, None))
                notified = True
                continue

            try:
                success, response, error_msg = generator(request_json_body, backend_url)
                DataEvent(event_id).set((success, response, error_msg))
                notified = True
            finally:
                # Always release the counters taken above, even when the
                # generation or the publish failed.
                decrement_ip_count(client_ip, 'processing_ips')
                decr_active_workers(selected_model, backend_url)
        except Exception:
            # Without this guard a single bad request would terminate the
            # thread for good and silently shrink the worker pool.
            traceback.print_exc()
            if not notified:
                try:
                    # Same tuple shape as the success path; presumably the
                    # consumer treats a falsy first element as failure -
                    # verify against the code that waits on the event.
                    DataEvent(event_id).set((False, None, 'The inference worker encountered an unexpected error.'))
                except Exception:
                    # The event channel itself is failing; all we can do is log.
                    traceback.print_exc()
def start_workers(num_workers: int):
    """Spawn ``num_workers`` daemon threads running ``worker`` and report the count.

    A zero or negative ``num_workers`` starts nothing and reports 0.
    """
    pool = [threading.Thread(target=worker, daemon=True) for _ in range(num_workers)]
    for thread in pool:
        thread.start()
    print(f'Started {len(pool)} inference workers.')
def need_to_wait(backend_url: str):
    """Block until the backend at ``backend_url`` has a free generation slot.

    Polls the ``active_gen_workers:<backend_url>`` counter in redis every
    10 ms until it drops below the backend's ``concurrent_gens`` setting
    (1 when the backend does not define it).  Prints a notice when the wait
    exceeded half a second.
    """
    concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
    counter_key = f'active_gen_workers:{backend_url}'

    # Monotonic clock: the measured delay is unaffected by wall-clock changes.
    started = time.monotonic()

    # We need to check the number of active workers since the streaming
    # endpoint may be doing something.  The counter must be re-read on every
    # pass: reading it only once would spin forever whenever the backend is
    # saturated at the moment of the first check.
    while redis.get(counter_key, 0, dtype=int) >= concurrent_gens:
        time.sleep(0.01)

    delay = time.monotonic() - started
    if delay > 0.5:
        print(f'Worker was delayed {delay} seconds.')