From 0aa52863bc7cf46153108be3ec8341c3b0724896 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Wed, 23 Aug 2023 20:33:49 -0600
Subject: [PATCH] forgot to start workers

---
 llm_server/llm/generator.py | 4 ++--
 llm_server/routes/queue.py  | 9 ++++++---
 llm_server/routes/stats.py  | 6 +-----
 server.py                   | 5 ++++-
 4 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py
index 88d829f..b2281f0 100644
--- a/llm_server/llm/generator.py
+++ b/llm_server/llm/generator.py
@@ -3,10 +3,10 @@ from llm_server import opts
 
 def generator(request_json_body):
     if opts.mode == 'oobabooga':
-        from oobabooga.generate import generate
+        from .oobabooga.generate import generate
         return generate(request_json_body)
     elif opts.mode == 'hf-textgen':
-        from hf_textgen.generate import generate
+        from .hf_textgen.generate import generate
         return generate(request_json_body)
     else:
         raise Exception
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index c2b9a50..d3dfeae 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -22,7 +22,10 @@ class PriorityQueue:
         with self._cv:
             while len(self._queue) == 0:
                 self._cv.wait()
-            return heapq.heappop(self._queue)[-1]
+            return heapq.heappop(self._queue)
+
+    def __len__(self):
+        return len(self._queue)
 
 
 priority_queue = PriorityQueue()
@@ -36,12 +39,12 @@ class DataEvent(threading.Event):
 
 def worker():
     while True:
-        request_json_body, client_ip, token, parameters, event = priority_queue.get()
+        priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
         success, response, error_msg = generator(request_json_body)
         event.data = (success, response, error_msg)
         event.set()
 
 
 def start_workers(num_workers: int):
-    for _ in range(3):
+    for _ in range(num_workers):
         threading.Thread(target=worker).start()
diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py
index 975b0ce..83b0fb4 100644
--- a/llm_server/routes/stats.py
+++ b/llm_server/routes/stats.py
@@ -1,10 +1,7 @@
-import collections
 import time
 from datetime import datetime
 from threading import Semaphore, Thread
 
-from llm_server import opts
-from llm_server.integer import ThreadSafeInteger
 from llm_server.opts import concurrent_gens
 from llm_server.routes.cache import redis
 
@@ -26,9 +23,8 @@ class SemaphoreCheckerThread(Thread):
     proompters_1_min = 0
     recent_prompters = {}
 
-    def __init__(self, semaphore):
+    def __init__(self):
         Thread.__init__(self)
-        self.semaphore = semaphore
         self.daemon = True
 
     def run(self):
diff --git a/server.py b/server.py
index 85d5dee..76266fe 100644
--- a/server.py
+++ b/server.py
@@ -10,6 +10,7 @@ from llm_server.database import init_db
 from llm_server.helpers import resolve_path
 from llm_server.routes.cache import cache
 from llm_server.routes.helpers.http import cache_control
+from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
 from llm_server.routes.v1 import bp
 
@@ -53,7 +54,9 @@ if not opts.verify_ssl:
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
 
-SemaphoreCheckerThread(concurrent_semaphore).start()
+start_workers(opts.concurrent_gens)
+
+SemaphoreCheckerThread().start()
 
 app = Flask(__name__)
 cache.init_app(app)
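
For context, the pattern this patch wires together is a heapq-backed priority queue guarded by a condition variable, a pool of worker threads started by start_workers(), and a DataEvent that carries the generator result back to the enqueuing request handler. Below is a minimal, self-contained sketch of that flow. The put() method and the stub generator() are assumptions for illustration only; the patch itself only shows get(), __len__(), worker() and start_workers(), and the real generator lives in llm_server/llm/generator.py. The sketch also uses daemon threads so the demo exits, whereas the patch starts non-daemon workers.

# Minimal, self-contained sketch of the queue/worker flow this patch starts up.
# Assumptions for illustration only: the put() method, the stub generator(), and
# daemon worker threads.
import heapq
import threading
from itertools import count


class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._cv = threading.Condition()
        self._index = count()  # tie-breaker so equal priorities stay FIFO

    def put(self, priority, payload, event):
        # Hypothetical producer side; entries match what worker() unpacks below.
        with self._cv:
            heapq.heappush(self._queue, (priority, next(self._index), payload, event))
            self._cv.notify()

    def get(self):
        # Block until an entry is available, then return the whole heap tuple.
        with self._cv:
            while len(self._queue) == 0:
                self._cv.wait()
            return heapq.heappop(self._queue)

    def __len__(self):
        return len(self._queue)


class DataEvent(threading.Event):
    def __init__(self):
        super().__init__()
        self.data = None


def generator(request_json_body):
    # Stand-in for the real backend call; returns (success, response, error_msg).
    return True, {'echo': request_json_body}, None


priority_queue = PriorityQueue()


def worker():
    while True:
        priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
        success, response, error_msg = generator(request_json_body)
        event.data = (success, response, error_msg)
        event.set()


def start_workers(num_workers: int):
    for _ in range(num_workers):
        threading.Thread(target=worker, daemon=True).start()


if __name__ == '__main__':
    start_workers(2)  # server.py now calls start_workers(opts.concurrent_gens)
    event = DataEvent()
    priority_queue.put(1, ({'prompt': 'hello'}, '127.0.0.1', None, {}), event)
    event.wait()
    print(event.data)

The (priority, index, payload, event) tuple layout matches what worker() unpacks after this change, and the monotonically increasing index keeps heapq from ever comparing the payload dicts when two entries share the same priority.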