This repository has been archived on 2024-10-27. You can view files and clone it, but cannot push or open issues or pull requests.
local-llm-server/llm_server/workers/inferencer.py

56 lines
2.3 KiB
Python
Raw Normal View History

import threading
import time
2023-10-05 21:37:18 -06:00
from uuid import uuid4
2023-10-05 21:37:18 -06:00
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis, RedisCustom
2023-09-29 00:09:44 -06:00
from llm_server.llm.generator import generator
2023-10-05 21:37:18 -06:00
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, RedisPriorityQueue, PriorityQueue, priority_queue
2023-10-05 21:37:18 -06:00
def _wait_for_stream_handler(event_id):
    """Hold this worker's slot until the streaming handler reports completion.

    Subscribes to the ``event_id`` channel, publishes ``begin`` on it, then
    blocks on the subscription until a ``finished`` message arrives. The
    pubsub object is always closed afterwards so a connection is not leaked
    for every streamed request.

    NOTE(review): if the handler never publishes ``finished`` this blocks
    indefinitely (same as before) — consider a timeout.
    """
    pubsub = redis.pubsub()
    try:
        pubsub.subscribe(event_id)
        redis.publish(event_id, 'begin')
        for item in pubsub.listen():
            if item['type'] == 'message':
                data = item['data']
                # Accept both bytes and str so this works whether or not the
                # redis client decodes responses itself.
                if isinstance(data, (bytes, bytearray)):
                    data = data.decode('utf-8')
                if data == 'finished':
                    break
            time.sleep(0.1)
    finally:
        pubsub.close()


def _run_inference(request_json_body, backend_url, event_id):
    """Run ``generator`` for a normal (non-streaming) job and publish the result.

    The ``(success, response, error_msg)`` tuple is delivered through a
    ``DataEvent`` keyed by ``event_id``. If ``generator`` raises, a failure
    tuple is published first so the waiting client is not left hanging, and
    the exception is re-raised for the caller to log.
    """
    try:
        success, response, error_msg = generator(request_json_body, backend_url)
    except Exception as e:
        # NOTE(review): assumes consumers handle (False, None, msg) the same
        # way as a failed generator() result — confirm against the handler.
        DataEvent(event_id).set((False, None, f'{e.__class__.__name__}: {e}'))
        raise
    DataEvent(event_id).set((success, response, error_msg))


def worker(backend_url):
    """Serve inference jobs for ``backend_url`` forever.

    Each iteration pops a job from the backend's ``RedisPriorityQueue``. It
    increments the client's ``processing_ips`` count and the active-worker
    count for the selected model, handles the job, and then always
    decrements both counts again.

    A job with an empty request body is a placeholder from the streaming
    handlers. The handler streams the response itself, and the worker only
    blocks until the handler is finished. Because ratelimiting and stats are
    based on the number of active workers, the generation must still be
    accounted to a worker. Any other job is run through ``generator``, and
    its result is delivered via a ``DataEvent``.

    Exceptions raised while handling a job are printed and swallowed. One
    failed request therefore cannot end this thread and permanently reduce
    the backend's concurrency.
    """
    import traceback  # function-scope import keeps this block self-contained

    queue = RedisPriorityQueue(backend_url)
    while True:
        (request_json_body, client_ip, _token, _parameters), event_id, selected_model = queue.get()
        backend_info = cluster_config.get_backend(backend_url)
        if not selected_model:
            # Fall back to the model configured for this backend.
            selected_model = backend_info['model']

        increment_ip_count(client_ip, 'processing_ips')
        incr_active_workers(selected_model, backend_url)
        try:
            if not request_json_body:
                _wait_for_stream_handler(event_id)
            else:
                _run_inference(request_json_body, backend_url, event_id)
        except Exception:
            # Thread-level boundary: log and keep serving the queue. An
            # uncaught exception here would terminate this worker thread.
            traceback.print_exc()
        finally:
            decrement_ip_count(client_ip, 'processing_ips')
            decr_active_workers(selected_model, backend_url)
2023-10-05 21:37:18 -06:00
def start_workers(cluster: list):
    """Start the daemon inference worker threads for every backend.

    :param cluster: sequence of backend config dicts. Each element must
        provide ``backend_url`` and ``concurrent_gens``, the number of
        ``worker`` threads to start for that backend. (The elements are
        iterated and indexed by key, so a plain dict of backends would not
        work here.)

    Prints the total number of threads started.
    """
    started = 0
    for backend in cluster:
        backend_url = backend['backend_url']
        for n in range(backend['concurrent_gens']):
            t = threading.Thread(
                target=worker,
                args=(backend_url,),
                # A descriptive name makes stack dumps / debuggers readable.
                name=f'inference-worker-{backend_url}-{n}',
                daemon=True,
            )
            t.start()
            started += 1
    print(f'Started {started} inference workers.')