diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py
index be6c78f..25f3326 100644
--- a/llm_server/database/conn.py
+++ b/llm_server/database/conn.py
@@ -1,28 +1,28 @@
 import pymysql
-from dbutils.pooled_db import PooledDB
 
 
 class DatabaseConnection:
-    db_pool: PooledDB = None
+    host: str = None
+    username: str = None
+    password: str = None
+    database: str = None
 
     def init_db(self, host, username, password, database):
-        self.db_pool = PooledDB(
-            creator=pymysql,
-            maxconnections=10,
-            host=host,
-            user=username,
-            password=password,
-            database=database,
+        self.host = host
+        self.username = username
+        self.password = password
+        self.database = database
+
+    def cursor(self):
+        db = pymysql.connect(
+            host=self.host,
+            user=self.username,
+            password=self.password,
+            database=self.database,
             charset='utf8mb4',
             autocommit=True,
         )
-
-        # Test it.
-        conn = self.db_pool.connection()
-        del conn
-
-    def connection(self):
-        return self.db_pool.connection()
+        return db.cursor()
 
 
-db_pool = DatabaseConnection()
+database = DatabaseConnection()
diff --git a/llm_server/database/create.py b/llm_server/database/create.py
index 08ad8a2..c1788ae 100644
--- a/llm_server/database/create.py
+++ b/llm_server/database/create.py
@@ -1,9 +1,8 @@
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 
 
 def create_db():
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     cursor.execute('''
         CREATE TABLE IF NOT EXISTS prompts (
             ip TEXT,
@@ -38,5 +37,4 @@ def create_db():
             disabled BOOLEAN DEFAULT 0
         )
     ''')
-    conn.commit()
     cursor.close()
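
A minimal usage sketch of the reworked connection class above (not part of the diff; the host and credentials are placeholders):

    from llm_server.database.conn import database

    # One-time setup at startup (replaces the old db_pool.init_db()).
    database.init_db('localhost', 'llm_user', 'hunter2', 'llm_server')

    # Each cursor() call opens a fresh pymysql connection with autocommit=True,
    # so there is no pool and no explicit commit step; close the cursor when done.
    cursor = database.cursor()
    try:
        cursor.execute('SELECT COUNT(*) FROM prompts')
        print(cursor.fetchone()[0])
    finally:
        cursor.close()
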
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index c84fad7..ae1e70d 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -4,7 +4,7 @@ import traceback
 
 import llm_server
 from llm_server import opts
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 from llm_server.llm.vllm import tokenize
 from llm_server.routes.cache import redis
 
@@ -37,8 +37,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     running_model = redis.get('running_model', str, 'ERROR')
     timestamp = int(time.time())
 
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("""
            INSERT INTO prompts
@@ -51,8 +50,7 @@
 
 
 def is_valid_api_key(api_key):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
@@ -69,12 +67,10 @@ def is_valid_api_key(api_key):
 def is_api_key_moderated(api_key):
     if not api_key:
         return opts.openai_moderation_enabled
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
-        print(bool(row[0]))
         if row is not None:
             return bool(row[0])
         return opts.openai_moderation_enabled
@@ -83,8 +79,7 @@
 
 
 def get_number_of_rows(table_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
@@ -94,8 +89,7 @@
 
 
 def average_column(table_name, column_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
@@ -105,8 +99,7 @@ def average_column(table_name, column_name):
 
 
 def average_column_for_model(table_name, column_name, model_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
         result = cursor.fetchone()
@@ -121,8 +114,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
     else:
         sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
 
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         try:
             cursor.execute(sql, (model_name, backend_name, backend_url,))
@@ -152,8 +144,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
 
 
 def sum_column(table_name, column_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
    try:
        cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
        result = cursor.fetchone()
@@ -165,8 +156,7 @@ def sum_column(table_name, column_name):
 def get_distinct_ips_24h():
     # Get the current time and subtract 24 hours (in seconds)
     past_24_hours = int(time.time()) - 24 * 60 * 60
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
         result = cursor.fetchone()
@@ -176,8 +166,7 @@ def get_distinct_ips_24h():
 
 
 def increment_token_uses(token):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
     finally:
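
The helpers above now each open a short-lived cursor per call, so callers just invoke them directly; a small caller sketch, with the column name illustrative rather than taken from this diff:

    from llm_server.database.database import average_column, get_distinct_ips_24h

    avg_gen_time = average_column('prompts', 'generation_time')  # AVG() over non-SYSTEM rows; may be None
    unique_ips = get_distinct_ips_24h()                          # distinct client IPs seen in the last 24 hours
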
diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py
index 0cb0726..d5b64e3 100644
--- a/llm_server/llm/openai/transform.py
+++ b/llm_server/llm/openai/transform.py
@@ -7,7 +7,7 @@ import traceback
 from typing import Dict, List
 
 import tiktoken
-from flask import jsonify
+from flask import jsonify, make_response
 
 import llm_server
 from llm_server import opts
@@ -36,7 +36,7 @@ def build_openai_response(prompt, response, model=None):
     response_tokens = llm_server.llm.get_token_count(response)
     running_model = redis.get('running_model', str, 'ERROR')
 
-    return jsonify({
+    response = make_response(jsonify({
         "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
@@ -55,7 +55,12 @@ def build_openai_response(prompt, response, model=None):
             "completion_tokens": response_tokens,
             "total_tokens": prompt_tokens + response_tokens
         }
-    })
+    }), 200)
+
+    stats = redis.get('proxy_stats', dict)
+    if stats:
+        response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+    return response
 
 
 def generate_oai_string(length=24):
diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
index 5df101f..5a8d09a 100644
--- a/llm_server/llm/vllm/tokenize.py
+++ b/llm_server/llm/vllm/tokenize.py
@@ -16,5 +16,4 @@ def tokenize(prompt: str) -> int:
         return j['length']
     except:
         traceback.print_exc()
-        print(prompt)
         return len(tokenizer.encode(prompt)) + 10
diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py
index 24a1dc3..20b37b6 100644
--- a/llm_server/routes/helpers/http.py
+++ b/llm_server/routes/helpers/http.py
@@ -83,7 +83,7 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
             return True, data
         elif isinstance(data, bytes):
             s = data.decode('utf-8')
-            return json.loads(s)
+            return False, json.loads(s)
     except Exception as e:
         return False, e
     try:
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index b04d473..11ab1e8 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -28,11 +28,16 @@ class OobaRequestHandler(RequestHandler):
         return backend_response
 
     def handle_ratelimited(self):
-        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
-        return jsonify({
-            'results': [{'text': backend_response}]
-        }), 429
+        msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
+        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
+        if disable_st_error_formatting:
+            return msg, 429
+        else:
+            backend_response = format_sillytavern_err(msg, 'error')
+            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
+            return jsonify({
+                'results': [{'text': backend_response}]
+            }), 429
 
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({
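
A client-side sketch of the LLM-ST-Errors opt-out that this commit threads through the handlers; the base URL is a placeholder and the example is an assumption about typical client usage, not code from the commit:

    import requests

    # With the header set to 'true', a ratelimited request comes back as a plain
    # 429 message instead of a SillyTavern-formatted error embedded in the usual
    # {'results': [{'text': ...}]} payload.
    r = requests.post(
        'http://localhost:5000/api/v1/generate',
        json={'prompt': 'Hello', 'max_new_tokens': 16},
        headers={'LLM-ST-Errors': 'true'},
    )
    if r.status_code == 429:
        print(r.text)  # e.g. 'Ratelimited: you are only allowed to have ...'
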
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 0968cf3..2f05b65 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -10,7 +10,7 @@ from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
-from ...llm.openai.transform import build_openai_response, generate_oai_string
+from ...llm.openai.transform import build_openai_response, generate_oai_string, transform_messages_to_prompt
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
@@ -21,6 +21,7 @@ from ...llm.vllm import tokenize
 
 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
@@ -42,7 +43,7 @@ def openai_chat_completions():
             # TODO: simulate OAI here
             raise Exception('TODO: simulate OAI here')
         else:
-            handler.prompt = handler.transform_messages_to_prompt()
+            handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
             msg_to_backend = {
                 **handler.parameters,
                 'prompt': handler.prompt,
@@ -112,4 +113,7 @@ def openai_chat_completions():
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
         traceback.print_exc()
-        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+        if disable_st_error_formatting:
+            return '500', 500
+        else:
+            return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
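
A sketch of the call shape now used above, since transform_messages_to_prompt() is invoked as a module-level helper rather than a handler method; the message contents are illustrative, not from this commit:

    from llm_server.llm.openai.transform import transform_messages_to_prompt

    messages = [
        {'role': 'system', 'content': 'You are a helpful assistant.'},
        {'role': 'user', 'content': 'Hi there!'},
    ]
    prompt = transform_messages_to_prompt(messages)  # flattened text prompt for the backend
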
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 0070c5d..84e4542 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -1,8 +1,7 @@
 import time
-import time
 import traceback
 
-from flask import jsonify, request
+from flask import jsonify, make_response, request
 
 from . import openai_bp
 from ..cache import redis
@@ -11,7 +10,7 @@ from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...llm import get_token_count
-from ...llm.openai.transform import generate_oai_string
+from ...llm.openai.transform import build_openai_response, generate_oai_string
 
 
 # TODO: add rate-limit headers?
@@ -19,7 +18,6 @@ from ...llm.openai.transform import generate_oai_string
 @openai_bp.route('/completions', methods=['POST'])
 def openai_completions():
     disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
-
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
@@ -35,7 +33,7 @@ def openai_completions():
             response_tokens = get_token_count(output)
             running_model = redis.get('running_model', str, 'ERROR')
 
-            return jsonify({
+            response = make_response(jsonify({
                 "id": f"cmpl-{generate_oai_string(30)}",
                 "object": "text_completion",
                 "created": int(time.time()),
@@ -53,8 +51,16 @@ def openai_completions():
                 "completion_tokens": response_tokens,
                 "total_tokens": prompt_tokens + response_tokens
             }
-            })
+            }), 200)
+
+            stats = redis.get('proxy_stats', dict)
+            if stats:
+                response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+            return response
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!')
         print(traceback.format_exc())
-        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
+        if disable_st_error_formatting:
+            return '500', 500
+        else:
+            return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
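
A client-side sketch of reading the new rate-limit header added to both OpenAI-style endpoints; the endpoint URL is a placeholder and this assumes the proxy has published 'proxy_stats' to Redis:

    import requests

    r = requests.post(
        'http://localhost:5000/api/openai/v1/completions',
        json={'prompt': 'Say hi', 'max_tokens': 16},
    )
    # Estimated queue wait (seconds) taken from the proxy's cached stats.
    print(r.headers.get('x-ratelimit-reset-requests'))
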
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index e1483ca..89832fc 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -71,9 +71,14 @@ class OpenAIRequestHandler(RequestHandler):
         return backend_response, backend_response_status_code
 
     def handle_ratelimited(self):
-        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-        log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
-        return build_openai_response(self.prompt, backend_response), 429
+        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
+        if disable_st_error_formatting:
+            # TODO: format this like OpenAI does
+            return '429', 429
+        else:
+            backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
+            log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
+            return build_openai_response(self.prompt, backend_response), 429
 
     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         print(msg)
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 2b06748..16e3522 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -5,7 +5,7 @@ import flask
 from flask import Response, request
 
 from llm_server import opts
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 from llm_server.database.database import log_prompt
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
@@ -65,8 +65,7 @@ class RequestHandler:
         priority = DEFAULT_PRIORITY
         simultaneous_ip = opts.simultaneous_requests_per_ip
         if self.token:
-            conn = db_pool.connection()
-            cursor = conn.cursor()
+            cursor = database.cursor()
             try:
                 cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
                 result = cursor.fetchone()
@@ -120,10 +119,15 @@ class RequestHandler:
             else:
                 # Otherwise, just grab the first and only one.
                 combined_error_message = invalid_request_err_msgs[0] + '.'
-            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}', 'error')
+            msg = f'Validation Error: {combined_error_message}'
+            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+            if disable_st_error_formatting:
+                backend_response = (Response(msg, 400), 400)
+            else:
+                backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))
             if do_log:
                 log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
-            return False, self.handle_error(backend_response)
+            return False, backend_response
         return True, (None, 0)
 
     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
@@ -163,9 +167,14 @@ class RequestHandler:
                 error_msg = 'Unknown error.'
             else:
                 error_msg = error_msg.strip('.') + '.'
-            backend_response = format_sillytavern_err(error_msg, 'error')
+
+            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+            if disable_st_error_formatting:
+                backend_response = (Response(error_msg, 400), 400)
+            else:
+                backend_response = format_sillytavern_err(error_msg, 'error')
             log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), self.handle_error(backend_response)
+            return (False, None, None, 0), backend_response
 
 
         # ===============================================
@@ -183,9 +192,14 @@ class RequestHandler:
         if return_json_err:
             error_msg = 'The backend did not return valid JSON.'
 
-            backend_response = format_sillytavern_err(error_msg, 'error')
+            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+            if disable_st_error_formatting:
+                # TODO: how to format this
+                backend_response = (Response(error_msg, 400), 400)
+            else:
+                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
             log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), self.handle_error(backend_response)
+            return (False, None, None, 0), backend_response
 
 
        # ===============================================
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 53f410c..49fc43e 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -10,6 +10,7 @@ from ..ooba_request_handler import OobaRequestHandler
 
 @bp.route('/generate', methods=['POST'])
 def generate():
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
@@ -19,4 +20,7 @@ def generate():
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!')
         print(traceback.format_exc())
-        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
+        if disable_st_error_formatting:
+            return '500', 500
+        else:
+            return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index b2a80ca..789fb4f 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -25,6 +25,7 @@ def stream(ws):
 
     r_headers = dict(request.headers)
     r_url = request.url
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     message_num = 0
 
     while ws.connected:
@@ -134,22 +135,23 @@ def stream(ws):
                     thread.start()
                     thread.join()
                 except:
-                    generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                    traceback.print_exc()
-                    ws.send(json.dumps({
-                        'event': 'text_stream',
-                        'message_num': message_num,
-                        'text': generated_text
-                    }))
+                    if not disable_st_error_formatting:
+                        generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
+                        traceback.print_exc()
+                        ws.send(json.dumps({
+                            'event': 'text_stream',
+                            'message_num': message_num,
+                            'text': generated_text
+                        }))
 
-                    def background_task_exception():
-                        generated_tokens = tokenize(generated_text)
-                        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
+                        def background_task_exception():
+                            generated_tokens = tokenize(generated_text)
+                            log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
 
-                    # TODO: use async/await instead of threads
-                    thread = threading.Thread(target=background_task_exception)
-                    thread.start()
-                    thread.join()
+                        # TODO: use async/await instead of threads
+                        thread = threading.Thread(target=background_task_exception)
+                        thread.start()
+                        thread.join()
                 try:
                     ws.send(json.dumps({
                         'event': 'stream_end',
diff --git a/server.py b/server.py
index c8b1b0d..3f56cbb 100644
--- a/server.py
+++ b/server.py
@@ -16,7 +16,7 @@ import simplejson as json
 from flask import Flask, jsonify, render_template, request
 
 import llm_server
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 from llm_server.database.create import create_db
 from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
@@ -25,13 +25,9 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio
 
-# TODO: make sure prompts are logged even when the user cancels generation
 # TODO: add some sort of loadbalancer to send requests to a group of backends
-# TODO: use the current estimated wait time for ratelimit headers on openai
-# TODO: accept a header to specify if the openai endpoint should return sillytavern-formatted errors
 # TODO: allow setting concurrent gens per-backend
 # TODO: use first backend as default backend
-# TODO: allow disabling OpenAI moderation endpoint per-token
 # TODO: allow setting specific simoltaneous IPs allowed per token
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 
@@ -83,7 +79,7 @@ if not success:
 if config['database_path'].startswith('./'):
     config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
 
-db_pool.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()
 
 if config['mode'] not in ['oobabooga', 'vllm']:
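
For completeness, the startup wiring after the rename, restated from the server.py hunk above with placeholder MySQL settings:

    from llm_server.database.conn import database
    from llm_server.database.create import create_db

    database.init_db('localhost', 'llm_user', 'hunter2', 'llm_server')
    create_db()  # creates the prompts / token_auth tables if they do not already exist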