forgot to start workers
This commit is contained in:
parent
6f8b70df54
commit
0aa52863bc
|
@ -3,10 +3,10 @@ from llm_server import opts
|
|||
|
||||
def generator(request_json_body):
|
||||
if opts.mode == 'oobabooga':
|
||||
from oobabooga.generate import generate
|
||||
from .oobabooga.generate import generate
|
||||
return generate(request_json_body)
|
||||
elif opts.mode == 'hf-textgen':
|
||||
from hf_textgen.generate import generate
|
||||
from .hf_textgen.generate import generate
|
||||
return generate(request_json_body)
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -22,7 +22,10 @@ class PriorityQueue:
|
|||
with self._cv:
|
||||
while len(self._queue) == 0:
|
||||
self._cv.wait()
|
||||
return heapq.heappop(self._queue)[-1]
|
||||
return heapq.heappop(self._queue)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._queue)
|
||||
|
||||
|
||||
priority_queue = PriorityQueue()
|
||||
|
@ -36,12 +39,12 @@ class DataEvent(threading.Event):
|
|||
|
||||
def worker():
|
||||
while True:
|
||||
request_json_body, client_ip, token, parameters, event = priority_queue.get()
|
||||
priority, index, (request_json_body, client_ip, token, parameters), event = priority_queue.get()
|
||||
success, response, error_msg = generator(request_json_body)
|
||||
event.data = (success, response, error_msg)
|
||||
event.set()
|
||||
|
||||
|
||||
def start_workers(num_workers: int):
|
||||
for _ in range(3):
|
||||
for _ in range(num_workers):
|
||||
threading.Thread(target=worker).start()
|
||||
|
|
|
@ -1,10 +1,7 @@
|
|||
import collections
|
||||
import time
|
||||
from datetime import datetime
|
||||
from threading import Semaphore, Thread
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.integer import ThreadSafeInteger
|
||||
from llm_server.opts import concurrent_gens
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
@ -26,9 +23,8 @@ class SemaphoreCheckerThread(Thread):
|
|||
proompters_1_min = 0
|
||||
recent_prompters = {}
|
||||
|
||||
def __init__(self, semaphore):
|
||||
def __init__(self):
|
||||
Thread.__init__(self)
|
||||
self.semaphore = semaphore
|
||||
self.daemon = True
|
||||
|
||||
def run(self):
|
||||
|
|
|
@ -10,6 +10,7 @@ from llm_server.database import init_db
|
|||
from llm_server.helpers import resolve_path
|
||||
from llm_server.routes.cache import cache
|
||||
from llm_server.routes.helpers.http import cache_control
|
||||
from llm_server.routes.queue import start_workers
|
||||
from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
|
||||
from llm_server.routes.v1 import bp
|
||||
|
||||
|
@ -53,7 +54,9 @@ if not opts.verify_ssl:
|
|||
|
||||
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
||||
|
||||
SemaphoreCheckerThread(concurrent_semaphore).start()
|
||||
start_workers(opts.concurrent_gens)
|
||||
|
||||
SemaphoreCheckerThread().start()
|
||||
|
||||
app = Flask(__name__)
|
||||
cache.init_app(app)
|
||||
|
|
Reference in New Issue