forgot to start workers

This commit is contained in:
Cyberes 2023-08-23 20:33:49 -06:00
parent 6f8b70df54
commit 0aa52863bc
4 changed files with 13 additions and 11 deletions

View File

@@ -3,10 +3,10 @@ from llm_server import opts
def generator(request_json_body):
if opts.mode == 'oobabooga':
from oobabooga.generate import generate
from .oobabooga.generate import generate
return generate(request_json_body)
elif opts.mode == 'hf-textgen':
from hf_textgen.generate import generate
from .hf_textgen.generate import generate
return generate(request_json_body)
else:
raise Exception

View File

@@ -22,7 +22,10 @@ class PriorityQueue:
with self._cv:
while len(self._queue) == 0:
self._cv.wait()
return heapq.heappop(self._queue)[-1]
return heapq.heappop(self._queue)
def __len__(self):
return len(self._queue)
priority_queue = PriorityQueue()
@@ -36,12 +39,12 @@ class DataEvent(threading.Event):
def worker():
while True:
request_json_body, client_ip, token, parameters, event = priority_queue.get()
priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
success, response, error_msg = generator(request_json_body)
event.data = (success, response, error_msg)
event.set()
def start_workers(num_workers: int):
for _ in range(3):
for _ in range(num_workers):
threading.Thread(target=worker).start()

View File

@@ -1,10 +1,7 @@
import collections
import time
from datetime import datetime
from threading import Semaphore, Thread
from llm_server import opts
from llm_server.integer import ThreadSafeInteger
from llm_server.opts import concurrent_gens
from llm_server.routes.cache import redis
@@ -26,9 +23,8 @@ class SemaphoreCheckerThread(Thread):
proompters_1_min = 0
recent_prompters = {}
def __init__(self, semaphore):
def __init__(self):
Thread.__init__(self)
self.semaphore = semaphore
self.daemon = True
def run(self):

View File

@@ -10,6 +10,7 @@ from llm_server.database import init_db
from llm_server.helpers import resolve_path
from llm_server.routes.cache import cache
from llm_server.routes.helpers.http import cache_control
from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
from llm_server.routes.v1 import bp
@@ -53,7 +54,9 @@ if not opts.verify_ssl:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
SemaphoreCheckerThread(concurrent_semaphore).start()
start_workers(opts.concurrent_gens)
SemaphoreCheckerThread().start()
app = Flask(__name__)
cache.init_app(app)