forgot to start workers

This commit is contained in:
Cyberes 2023-08-23 20:33:49 -06:00
parent 6f8b70df54
commit 0aa52863bc
4 changed files with 13 additions and 11 deletions

View File

@@ -3,10 +3,10 @@ from llm_server import opts
 
 def generator(request_json_body):
     if opts.mode == 'oobabooga':
-        from oobabooga.generate import generate
+        from .oobabooga.generate import generate
         return generate(request_json_body)
     elif opts.mode == 'hf-textgen':
-        from hf_textgen.generate import generate
+        from .hf_textgen.generate import generate
         return generate(request_json_body)
     else:
         raise Exception

View File

@@ -22,7 +22,10 @@ class PriorityQueue:
         with self._cv:
             while len(self._queue) == 0:
                 self._cv.wait()
-            return heapq.heappop(self._queue)[-1]
+            return heapq.heappop(self._queue)
+
+    def __len__(self):
+        return len(self._queue)
 
 
 priority_queue = PriorityQueue()
@@ -36,12 +39,12 @@ class DataEvent(threading.Event):
 
 def worker():
     while True:
-        request_json_body, client_ip, token, parameters, event = priority_queue.get()
+        priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
         success, response, error_msg = generator(request_json_body)
         event.data = (success, response, error_msg)
         event.set()
 
 
 def start_workers(num_workers: int):
-    for _ in range(3):
+    for _ in range(num_workers):
         threading.Thread(target=worker).start()

View File

@@ -1,10 +1,7 @@
-import collections
 import time
 from datetime import datetime
 from threading import Semaphore, Thread
 
-from llm_server import opts
-from llm_server.integer import ThreadSafeInteger
 from llm_server.opts import concurrent_gens
 from llm_server.routes.cache import redis
 
@@ -26,9 +23,8 @@ class SemaphoreCheckerThread(Thread):
     proompters_1_min = 0
     recent_prompters = {}
 
-    def __init__(self, semaphore):
+    def __init__(self):
         Thread.__init__(self)
-        self.semaphore = semaphore
         self.daemon = True
 
     def run(self):

View File

@@ -10,6 +10,7 @@ from llm_server.database import init_db
 from llm_server.helpers import resolve_path
 from llm_server.routes.cache import cache
 from llm_server.routes.helpers.http import cache_control
+from llm_server.routes.queue import start_workers
 from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
 from llm_server.routes.v1 import bp
 
@@ -53,7 +54,9 @@ if not opts.verify_ssl:
     urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
-SemaphoreCheckerThread(concurrent_semaphore).start()
+start_workers(opts.concurrent_gens)
+SemaphoreCheckerThread().start()
 
 app = Flask(__name__)
 cache.init_app(app)