add a queue system
parent a79d67adbb
commit 6f8b70df54
@@ -29,13 +29,16 @@ def init_db(db_path):
         )
     ''')
     c.execute('''
-        CREATE TABLE token_auth
-        (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
+        CREATE TABLE token_auth (
+            token TEXT UNIQUE,
+            type TEXT NOT NULL,
+            priority INTEGER default 9999,
+            uses INTEGER default 0,
+            max_uses INTEGER,
+            expire INTEGER,
+            disabled BOOLEAN default 0
+        )
     ''')
     # c.execute('''
     #     CREATE TABLE leeches
     #     (url TEXT, online TEXT)
     # ''')
     conn.commit()
     conn.close()
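The rewritten token_auth table adds a priority column (default 9999) alongside the existing columns, so tokens created before this commit keep working. A minimal sketch of creating a token with a custom priority by hand, assuming the SQLite file that init_db() sets up; the token value, type, and priority here are made-up examples:

    import sqlite3

    conn = sqlite3.connect('llm_server.db')  # assumed path; in the app this comes from opts.database_path
    c = conn.cursor()
    # Columns left out of the INSERT fall back to the schema defaults above.
    c.execute("INSERT INTO token_auth (token, type, priority) VALUES (?, ?, ?)",
              ('example-api-key', 'api', 100))
    conn.commit()
    conn.close()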
@@ -0,0 +1,12 @@
+from llm_server import opts
+
+
+def generator(request_json_body):
+    if opts.mode == 'oobabooga':
+        from oobabooga.generate import generate
+        return generate(request_json_body)
+    elif opts.mode == 'hf-textgen':
+        from hf_textgen.generate import generate
+        return generate(request_json_body)
+    else:
+        raise Exception
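This new module lifts the backend dispatch out of the route handler so the queue workers below can call it too. A rough usage sketch; the return shape (success flag, response, error message) is inferred from how worker() unpacks it, and the request keys are made-up examples:

    from llm_server import opts
    from llm_server.llm.generator import generator

    opts.mode = 'oobabooga'  # or 'hf-textgen'
    success, response, error_msg = generator({'prompt': 'Hello.', 'max_new_tokens': 16})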
@@ -0,0 +1,47 @@
+import heapq
+import threading
+
+from llm_server.llm.generator import generator
+
+
+class PriorityQueue:
+    def __init__(self):
+        self._queue = []
+        self._index = 0
+        self._cv = threading.Condition()
+
+    def put(self, item, priority):
+        event = DataEvent()
+        with self._cv:
+            heapq.heappush(self._queue, (-priority, self._index, item, event))
+            self._index += 1
+            self._cv.notify()
+        return event
+
+    def get(self):
+        with self._cv:
+            while len(self._queue) == 0:
+                self._cv.wait()
+            return heapq.heappop(self._queue)[-2:]
+
+
+priority_queue = PriorityQueue()
+
+
+class DataEvent(threading.Event):
+    def __init__(self):
+        super().__init__()
+        self.data = None
+
+
+def worker():
+    while True:
+        (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        success, response, error_msg = generator(request_json_body)
+        event.data = (success, response, error_msg)
+        event.set()
+
+
+def start_workers(num_workers: int):
+    for _ in range(num_workers):
+        threading.Thread(target=worker).start()
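The queue itself is the usual heapq-behind-a-Condition pattern: put() pushes (-priority, index, item, event) and notifies one waiter; get() blocks until an entry exists and pops the smallest tuple. Because the priority is negated on the way in, numerically larger priority values pop first, and the ever-increasing index keeps ordering FIFO among entries with equal priority (it also guarantees the tuple comparison never falls through to the item itself). DataEvent is just a threading.Event with a data slot, which is how a worker hands the generation result back to the request thread blocked in event.wait(). A small end-to-end sketch with a made-up payload, assuming the module lives at llm_server/routes/queue.py as the relative import in the route file suggests:

    from llm_server.routes.queue import priority_queue, start_workers

    start_workers(3)  # spawn three consumer threads

    # put() returns the event the caller blocks on; the item mirrors what
    # /generate enqueues: (request_json_body, client_ip, token, parameters).
    event = priority_queue.put(({'prompt': 'hi'}, '127.0.0.1', None, {}), 9999)
    event.wait()                                # set by worker() once generation finishes
    success, response, error_msg = event.data  # tuple stored by the worker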
@@ -1,3 +1,4 @@
+import sqlite3
 import time
 
 from flask import jsonify, request
@@ -6,25 +7,16 @@ from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
 from . import bp
 from ..cache import redis
 from ..helpers.client import format_sillytavern_err
-from ..helpers.http import cache_control, validate_json
+from ..helpers.http import validate_json
+from ..queue import priority_queue
 from ... import opts
 from ...database import log_prompt
 from ...helpers import safe_list_get
 
 
-def generator(request_json_body):
-    if opts.mode == 'oobabooga':
-        from ...llm.oobabooga.generate import generate
-        return generate(request_json_body)
-    elif opts.mode == 'hf-textgen':
-        from ...llm.hf_textgen.generate import generate
-        return generate(request_json_body)
-    else:
-        raise Exception
+DEFAULT_PRIORITY = 9999
 
 
 @bp.route('/generate', methods=['POST'])
 @cache_control(-1)
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
     if not request_valid_json:
@@ -40,12 +32,30 @@ def generate():
 
     SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()
 
-    token = request.headers.get('X-Api-Key')
 
     parameters = request_json_body.copy()
     del parameters['prompt']
 
-    success, response, error_msg = generator(request_json_body)
+    token = request.headers.get('X-Api-Key')
+    priority = None
+    if token:
+        conn = sqlite3.connect(opts.database_path)
+        cursor = conn.cursor()
+        cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
+        result = cursor.fetchone()
+        if result:
+            priority = result[0]
+        conn.close()
+
+    if priority is None:
+        priority = DEFAULT_PRIORITY
+    else:
+        print(f'Token {token} was given priority {priority}.')
+
+    # success, response, error_msg = generator(request_json_body)
+    event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
+    event.wait()
+    success, response, error_msg = event.data
 
     if not success:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
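With this change /generate no longer calls the backend inline; it enqueues the request and parks on the event until a worker has produced (success, response, error_msg). From the client's side nothing changes, except that a recognized X-Api-Key makes the request enter the queue with the token's stored priority instead of DEFAULT_PRIORITY. A hypothetical client call; the host, port, and URL prefix are assumptions, since only the blueprint route '/generate' is visible here:

    import requests

    r = requests.post('http://localhost:5000/api/v1/generate',        # assumed prefix
                      json={'prompt': 'Hello.', 'max_new_tokens': 16},
                      headers={'X-Api-Key': 'example-api-key'})       # optional; looked up in token_auth
    print(r.json())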
@@ -99,8 +109,3 @@ def generate():
             'error': 'the backend did not return valid JSON',
             **response_json_body
         }), 200
-
-# @openai_bp.route('/chat/completions', methods=['POST'])
-# def generate_openai():
-#     print(request.data)
-#     return '', 200