add a queue system
This commit is contained in:
parent
a79d67adbb
commit
6f8b70df54
|
@ -29,13 +29,16 @@ def init_db(db_path):
|
||||||
)
|
)
|
||||||
''')
|
''')
|
||||||
c.execute('''
|
c.execute('''
|
||||||
CREATE TABLE token_auth
|
CREATE TABLE token_auth (
|
||||||
(token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
|
token TEXT UNIQUE,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
priority INTEGER default 9999,
|
||||||
|
uses INTEGER default 0,
|
||||||
|
max_uses INTEGER,
|
||||||
|
expire INTEGER,
|
||||||
|
disabled BOOLEAN default 0
|
||||||
|
)
|
||||||
''')
|
''')
|
||||||
# c.execute('''
|
|
||||||
# CREATE TABLE leeches
|
|
||||||
# (url TEXT, online TEXT)
|
|
||||||
# ''')
|
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,12 @@
|
||||||
|
from llm_server import opts
|
||||||
|
|
||||||
|
|
||||||
|
def generator(request_json_body):
    """Dispatch a generation request to the backend selected by ``opts.mode``.

    :param request_json_body: dict payload forwarded verbatim to the chosen
        backend's ``generate()`` function.
    :return: whatever the backend's ``generate()`` returns — the queue worker
        unpacks it as ``(success, response, error_msg)``.
    :raises Exception: if ``opts.mode`` names an unknown backend.
    """
    # Imports are deferred so only the selected backend's dependencies load.
    if opts.mode == 'oobabooga':
        from oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'hf-textgen':
        from hf_textgen.generate import generate
        return generate(request_json_body)
    else:
        # Fix: the original raised a bare `Exception` with no message, giving
        # the operator no clue which config value was wrong. Type unchanged so
        # existing `except Exception` callers still work.
        raise Exception(f'Unknown backend mode: {opts.mode!r}')
|
|
@ -0,0 +1,47 @@
|
||||||
|
import heapq
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from llm_server.llm.generator import generator
|
||||||
|
|
||||||
|
|
||||||
|
class PriorityQueue:
    """Thread-safe priority queue of pending generation requests.

    Producers call ``put(item, priority)`` and receive a ``DataEvent`` to wait
    on; a worker thread pops the item, processes it, and fulfils the event.
    """

    def __init__(self):
        # Heap entries are (-priority, index, item, event).
        self._queue = []
        # Monotonic tie-breaker: preserves FIFO order among equal priorities
        # and prevents heapq from ever comparing items/events directly.
        self._index = 0
        self._cv = threading.Condition()

    def put(self, item, priority):
        """Enqueue *item* and return the DataEvent that will carry the result.

        NOTE(review): the heap key is ``-priority``, so a numerically LARGER
        priority value is served first — yet the DB column defaults to 9999,
        which would make tokenless requests outrank low-numbered priority
        tokens. Confirm the intended ordering.
        """
        event = DataEvent()
        with self._cv:
            heapq.heappush(self._queue, (-priority, self._index, item, event))
            self._index += 1
            self._cv.notify()
        return event

    def get(self):
        """Block until an item is available; return ``(*item, event)``.

        Bug fix: the original returned only the event (``heappop(...)[-1]``),
        but worker() unpacks five values
        ``request_json_body, client_ip, token, parameters, event`` —
        the queued item itself was silently discarded.
        """
        with self._cv:
            while not self._queue:
                self._cv.wait()
            _, _, item, event = heapq.heappop(self._queue)
            return (*item, event)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared by the request route (producer, via put())
# and the worker threads started by start_workers() (consumers, via get()).
priority_queue = PriorityQueue()
|
||||||
|
|
||||||
|
|
||||||
|
class DataEvent(threading.Event):
    """A ``threading.Event`` that also carries a result payload.

    The queue worker stores its ``(success, response, error_msg)`` tuple on
    ``data`` before calling ``set()``, so the request handler blocked in
    ``wait()`` can read the outcome once it wakes.
    """

    def __init__(self):
        super().__init__()
        # Result payload; stays None until a worker fulfils this event.
        self.data = None
|
||||||
|
|
||||||
|
|
||||||
|
def worker():
    """Consumer loop: drain the shared priority_queue forever.

    Pops the next request, runs it through generator(), then publishes the
    ``(success, response, error_msg)`` result on the request's DataEvent.
    """
    while True:
        request_json_body, client_ip, token, parameters, event = priority_queue.get()
        try:
            success, response, error_msg = generator(request_json_body)
        except Exception as e:
            # Fix: the original had no error handling — an exception from the
            # backend killed this worker thread AND left the request handler
            # blocked forever on event.wait(). Report failure instead.
            success, response, error_msg = False, None, f'worker exception: {e}'
        event.data = (success, response, error_msg)
        event.set()
|
||||||
|
|
||||||
|
|
||||||
|
def start_workers(num_workers: int):
    """Spawn *num_workers* background threads running worker().

    Bug fix: the original ignored ``num_workers`` entirely and always
    started exactly 3 threads (``for _ in range(3)``).
    """
    for _ in range(num_workers):
        # NOTE(review): threads are non-daemon, so they keep the process
        # alive at shutdown — confirm whether daemon=True is wanted.
        threading.Thread(target=worker).start()
|
|
@ -1,3 +1,4 @@
|
||||||
|
import sqlite3
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from flask import jsonify, request
|
from flask import jsonify, request
|
||||||
|
@ -6,25 +7,16 @@ from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
|
||||||
from . import bp
|
from . import bp
|
||||||
from ..cache import redis
|
from ..cache import redis
|
||||||
from ..helpers.client import format_sillytavern_err
|
from ..helpers.client import format_sillytavern_err
|
||||||
from ..helpers.http import cache_control, validate_json
|
from ..helpers.http import validate_json
|
||||||
|
from ..queue import priority_queue
|
||||||
from ... import opts
|
from ... import opts
|
||||||
from ...database import log_prompt
|
from ...database import log_prompt
|
||||||
from ...helpers import safe_list_get
|
from ...helpers import safe_list_get
|
||||||
|
|
||||||
|
# Fallback queue priority used when a request carries no X-Api-Key token, or
# the token has no matching priority row in token_auth (matches the table's
# `priority INTEGER default 9999` column default).
DEFAULT_PRIORITY = 9999
|
||||||
def generator(request_json_body):
|
|
||||||
if opts.mode == 'oobabooga':
|
|
||||||
from ...llm.oobabooga.generate import generate
|
|
||||||
return generate(request_json_body)
|
|
||||||
elif opts.mode == 'hf-textgen':
|
|
||||||
from ...llm.hf_textgen.generate import generate
|
|
||||||
return generate(request_json_body)
|
|
||||||
else:
|
|
||||||
raise Exception
|
|
||||||
|
|
||||||
|
|
||||||
@bp.route('/generate', methods=['POST'])
|
@bp.route('/generate', methods=['POST'])
|
||||||
@cache_control(-1)
|
|
||||||
def generate():
|
def generate():
|
||||||
request_valid_json, request_json_body = validate_json(request.data)
|
request_valid_json, request_json_body = validate_json(request.data)
|
||||||
if not request_valid_json:
|
if not request_valid_json:
|
||||||
|
@ -40,12 +32,30 @@ def generate():
|
||||||
|
|
||||||
SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()
|
SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()
|
||||||
|
|
||||||
token = request.headers.get('X-Api-Key')
|
|
||||||
|
|
||||||
parameters = request_json_body.copy()
|
parameters = request_json_body.copy()
|
||||||
del parameters['prompt']
|
del parameters['prompt']
|
||||||
|
|
||||||
success, response, error_msg = generator(request_json_body)
|
token = request.headers.get('X-Api-Key')
|
||||||
|
priority = None
|
||||||
|
if token:
|
||||||
|
conn = sqlite3.connect(opts.database_path)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
|
||||||
|
result = cursor.fetchone()
|
||||||
|
if result:
|
||||||
|
priority = result[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if priority is None:
|
||||||
|
priority = DEFAULT_PRIORITY
|
||||||
|
else:
|
||||||
|
print(f'Token {token} was given priority {priority}.')
|
||||||
|
|
||||||
|
# success, response, error_msg = generator(request_json_body)
|
||||||
|
event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
|
||||||
|
event.wait()
|
||||||
|
success, response, error_msg = event.data
|
||||||
|
|
||||||
if not success:
|
if not success:
|
||||||
if opts.mode == 'oobabooga':
|
if opts.mode == 'oobabooga':
|
||||||
backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
|
backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
|
||||||
|
@ -99,8 +109,3 @@ def generate():
|
||||||
'error': 'the backend did not return valid JSON',
|
'error': 'the backend did not return valid JSON',
|
||||||
**response_json_body
|
**response_json_body
|
||||||
}), 200
|
}), 200
|
||||||
|
|
||||||
# @openai_bp.route('/chat/completions', methods=['POST'])
|
|
||||||
# def generate_openai():
|
|
||||||
# print(request.data)
|
|
||||||
# return '', 200
|
|
||||||
|
|
Reference in New Issue