diff --git a/llm_server/database.py b/llm_server/database.py
index afa9519..9040074 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -29,13 +29,16 @@ def init_db(db_path):
         )
     ''')
     c.execute('''
-        CREATE TABLE token_auth
-        (token TEXT, type TEXT NOT NULL, uses INTEGER, max_uses INTEGER, expire INTEGER, disabled BOOLEAN default 0)
+        CREATE TABLE token_auth (
+            token TEXT UNIQUE,
+            type TEXT NOT NULL,
+            priority INTEGER default 9999,
+            uses INTEGER default 0,
+            max_uses INTEGER,
+            expire INTEGER,
+            disabled BOOLEAN default 0
+        )
     ''')
-    # c.execute('''
-    #     CREATE TABLE leeches
-    #     (url TEXT, online TEXT)
-    #     ''')
     conn.commit()
     conn.close()
 
diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py
new file mode 100644
index 0000000..88d829f
--- /dev/null
+++ b/llm_server/llm/generator.py
@@ -0,0 +1,12 @@
+from llm_server import opts
+
+
+def generator(request_json_body):
+    if opts.mode == 'oobabooga':
+        from llm_server.llm.oobabooga.generate import generate
+        return generate(request_json_body)
+    elif opts.mode == 'hf-textgen':
+        from llm_server.llm.hf_textgen.generate import generate
+        return generate(request_json_body)
+    else:
+        raise Exception
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
new file mode 100644
index 0000000..c2b9a50
--- /dev/null
+++ b/llm_server/routes/queue.py
@@ -0,0 +1,47 @@
+import heapq
+import threading
+
+from llm_server.llm.generator import generator
+
+
+class PriorityQueue:
+    def __init__(self):
+        self._queue = []
+        self._index = 0
+        self._cv = threading.Condition()
+
+    def put(self, item, priority):
+        event = DataEvent()
+        with self._cv:
+            heapq.heappush(self._queue, (-priority, self._index, item, event))
+            self._index += 1
+            self._cv.notify()
+        return event
+
+    def get(self):
+        with self._cv:
+            while len(self._queue) == 0:
+                self._cv.wait()
+            return heapq.heappop(self._queue)[2:]  # (item, event)
+
+
+priority_queue = PriorityQueue()
+
+
+class DataEvent(threading.Event):
+    def __init__(self):
+        super().__init__()
+        self.data = None
+
+
+def worker():
+    while True:
+        (request_json_body, client_ip, token, parameters), event = priority_queue.get()
+        success, response, error_msg = generator(request_json_body)
+        event.data = (success, response, error_msg)
+        event.set()
+
+
+def start_workers(num_workers: int):
+    for _ in range(num_workers):
+        threading.Thread(target=worker).start()
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index b284d46..ff97124 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -1,3 +1,4 @@
+import sqlite3
 import time
 
 from flask import jsonify, request
@@ -6,25 +7,16 @@ from llm_server.routes.stats import SemaphoreCheckerThread, concurrent_semaphore
 from . import bp
 from ..cache import redis
 from ..helpers.client import format_sillytavern_err
-from ..helpers.http import cache_control, validate_json
+from ..helpers.http import validate_json
+from ..queue import priority_queue
 from ... import opts
 from ...database import log_prompt
 from ...helpers import safe_list_get
 
-
-def generator(request_json_body):
-    if opts.mode == 'oobabooga':
-        from ...llm.oobabooga.generate import generate
-        return generate(request_json_body)
-    elif opts.mode == 'hf-textgen':
-        from ...llm.hf_textgen.generate import generate
-        return generate(request_json_body)
-    else:
-        raise Exception
+DEFAULT_PRIORITY = 9999
 
 
 @bp.route('/generate', methods=['POST'])
-@cache_control(-1)
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
     if not request_valid_json:
@@ -40,12 +32,30 @@ def generate():
 
     SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()
 
-    token = request.headers.get('X-Api-Key')
-
     parameters = request_json_body.copy()
     del parameters['prompt']
 
-    success, response, error_msg = generator(request_json_body)
+    token = request.headers.get('X-Api-Key')
+    priority = None
+    if token:
+        conn = sqlite3.connect(opts.database_path)
+        cursor = conn.cursor()
+        cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
+        result = cursor.fetchone()
+        if result:
+            priority = result[0]
+        conn.close()
+
+    if priority is None:
+        priority = DEFAULT_PRIORITY
+    else:
+        print(f'Token {token} was given priority {priority}.')
+
+    # success, response, error_msg = generator(request_json_body)
+    event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
+    event.wait()
+    success, response, error_msg = event.data
+
     if not success:
         if opts.mode == 'oobabooga':
             backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
@@ -99,8 +109,3 @@ def generate():
             'error': 'the backend did not return valid JSON',
             **response_json_body
         }), 200
-
-# @openai_bp.route('/chat/completions', methods=['POST'])
-# def generate_openai():
-#     print(request.data)
-#     return '', 200
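The core of this change is the hand-off between the Flask handler and the background worker: generate() enqueues its payload with priority_queue.put() and blocks on the returned DataEvent, while worker() pops the highest-priority item, calls the backend through generator(), and publishes the (success, response, error_msg) tuple back through that same event. Below is a minimal, self-contained sketch of that round trip; fake_backend and the sample payload are illustrative stand-ins, not code from this repository.

import heapq
import threading


class DataEvent(threading.Event):
    # threading.Event with a slot for the worker's result.
    def __init__(self):
        super().__init__()
        self.data = None


class PriorityQueue:
    def __init__(self):
        self._queue = []
        self._index = 0  # insertion counter so equal priorities stay FIFO
        self._cv = threading.Condition()

    def put(self, item, priority):
        event = DataEvent()
        with self._cv:
            # heapq is a min-heap; with the negation, numerically larger
            # priorities are popped first (mirroring the diff as written).
            heapq.heappush(self._queue, (-priority, self._index, item, event))
            self._index += 1
            self._cv.notify()
        return event

    def get(self):
        with self._cv:
            while not self._queue:
                self._cv.wait()
            _, _, item, event = heapq.heappop(self._queue)
            return item, event


def fake_backend(payload):
    # Stand-in for generator(): returns the same (success, response, error_msg) shape.
    return True, {'results': [{'text': 'echo: ' + payload['prompt']}]}, None


queue = PriorityQueue()


def worker():
    while True:
        payload, event = queue.get()
        event.data = fake_backend(payload)  # hand the result back to the waiting handler
        event.set()


threading.Thread(target=worker, daemon=True).start()

# The "request handler" side: enqueue, block until the worker signals, unpack.
event = queue.put({'prompt': 'hello'}, priority=9999)
event.wait()
success, response, error_msg = event.data
print(success, response, error_msg)

The condition variable keeps put() and get() thread-safe and lets idle workers sleep until work arrives, which is why start_workers() can simply spin up plain threads running worker() in an endless loop.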