add a queue system

Cyberes 2023-08-23 20:12:38 -06:00
parent a79d67adbb
commit 6f8b70df54
4 changed files with 93 additions and 26 deletions

@@ -29,13 +29,16 @@ def init_db(db_path):
        )
    ''')
    c.execute('''
-        CREATE TABLE token_auth
-        (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
+        CREATE TABLE token_auth (
+            token TEXT UNIQUE,
+            type TEXT NOT NULL,
+            priority INTEGER default 9999,
+            uses INTEGER default 0,
+            max_uses INTEGER,
+            expire INTEGER,
+            disabled BOOLEAN default 0
+        )
    ''')
    # c.execute('''
    # CREATE TABLE leeches
    # (url TEXT, online TEXT)
    # ''')
    conn.commit()
    conn.close()
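The new priority column is what the queue (added below) sorts on. A minimal sketch of registering a token with a custom priority; the database path, token, and type values are all hypothetical:

    import sqlite3

    conn = sqlite3.connect('llm-server.db')  # path is an assumption
    conn.execute(
        "INSERT INTO token_auth (token, type, priority) VALUES (?, ?, ?)",
        # Illustrative values only; see queue.py below for how priority
        # actually orders requests.
        ('example-token', 'api', 10000),
    )
    conn.commit()
    conn.close()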

@@ -0,0 +1,12 @@
from llm_server import opts


def generator(request_json_body):
    # Dispatch to whichever backend is configured. Relative imports,
    # since this module lives in llm_server/llm/.
    if opts.mode == 'oobabooga':
        from .oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'hf-textgen':
        from .hf_textgen.generate import generate
        return generate(request_json_body)
    else:
        raise Exception(f'Unknown mode: {opts.mode}')
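The route code below unpacks the backend's reply as a (success, response, error_msg) tuple, so a call to the dispatcher looks like this (payload keys are backend-specific and assumed here):

    from llm_server.llm.generator import generator

    success, response, error_msg = generator({'prompt': 'Hello', 'max_new_tokens': 16})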

@@ -0,0 +1,47 @@
import heapq
import threading

from llm_server.llm.generator import generator


class DataEvent(threading.Event):
    """An Event that also carries the backend's response back to the caller."""

    def __init__(self):
        super().__init__()
        self.data = None


class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._index = 0
        self._cv = threading.Condition()

    def put(self, item, priority):
        event = DataEvent()
        with self._cv:
            # heapq is a min-heap; storing -priority means numerically larger
            # priority values are popped first. _index breaks ties FIFO.
            heapq.heappush(self._queue, (-priority, self._index, item, event))
            self._index += 1
            self._cv.notify()
        return event

    def get(self):
        with self._cv:
            while len(self._queue) == 0:
                self._cv.wait()
            # Drop the (-priority, index) sort keys and return the request
            # tuple together with its DataEvent.
            _, _, item, event = heapq.heappop(self._queue)
            return item + (event,)


priority_queue = PriorityQueue()


def worker():
    while True:
        request_json_body, client_ip, token, parameters, event = priority_queue.get()
        success, response, error_msg = generator(request_json_body)
        event.data = (success, response, error_msg)
        event.set()


def start_workers(num_workers: int):
    for _ in range(num_workers):
        threading.Thread(target=worker).start()
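A minimal usage sketch of the queue API (module path inferred from the relative import in the route below): put() enqueues a request tuple and hands back a DataEvent, and the caller blocks until a worker fills in event.data. Note that because the heap stores -priority, numerically larger priority values are dequeued first.

    from llm_server.routes.queue import priority_queue, start_workers

    start_workers(3)  # spawn the worker pool; requires a configured backend
    event = priority_queue.put(({'prompt': 'hello'}, '127.0.0.1', None, {}), 9999)
    event.wait()  # set by a worker once generator() returns
    success, response, error_msg = event.data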

@@ -1,3 +1,4 @@
+import sqlite3
import time

from flask import jsonify, request
@@ -6,25 +7,16 @@ from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
from . import bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
-from ..helpers.http import cache_control, validate_json
+from ..helpers.http import validate_json
+from ..queue import priority_queue
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get

-def generator(request_json_body):
-    if opts.mode == 'oobabooga':
-        from ...llm.oobabooga.generate import generate
-        return generate(request_json_body)
-    elif opts.mode == 'hf-textgen':
-        from ...llm.hf_textgen.generate import generate
-        return generate(request_json_body)
-    else:
-        raise Exception
+DEFAULT_PRIORITY = 9999


@bp.route('/generate', methods=['POST'])
-@cache_control(-1)
def generate():
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
@@ -40,12 +32,30 @@ def generate():
    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

-    token = request.headers.get('X-Api-Key')
    parameters = request_json_body.copy()
    del parameters['prompt']

-    success, response, error_msg = generator(request_json_body)
+    token = request.headers.get('X-Api-Key')
+    priority = None
+    if token:
+        conn = sqlite3.connect(opts.database_path)
+        cursor = conn.cursor()
+        cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
+        result = cursor.fetchone()
+        if result:
+            priority = result[0]
+        conn.close()
+
+    if priority is None:
+        priority = DEFAULT_PRIORITY
+    else:
+        print(f'Token {token} was given priority {priority}.')
+
+    # success, response, error_msg = generator(request_json_body)
+    event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
+    event.wait()
+    success, response, error_msg = event.data

    if not success:
        if opts.mode == 'oobabooga':
            backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
@@ -99,8 +109,3 @@ def generate():
                'error': 'the backend did not return valid JSON',
                **response_json_body
            }), 200
-
-# @openai_bp.route('/chat/completions', methods=['POST'])
-# def generate_openai():
-#     print(request.data)
-#     return '', 200
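End to end, a request now blocks on its DataEvent while a worker pool drains the queue. A hedged sketch of exercising the route as a client; the host, port, and URL prefix are assumptions, since only the /generate route and the X-Api-Key header appear in this diff:

    import requests

    # X-Api-Key is looked up in token_auth; its priority column decides queue order.
    r = requests.post(
        'http://localhost:5000/api/v1/generate',  # host and prefix assumed
        headers={'X-Api-Key': 'example-token'},   # hypothetical token
        json={'prompt': 'Hello', 'max_new_tokens': 16},  # payload shape assumed
    )
    print(r.json())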