don't use db pooling, add LLM-ST-Errors header to disable formatted errors

This commit is contained in:
Cyberes 2023-09-26 23:59:22 -06:00
parent 7456bbe085
commit aba2e5b9c0
14 changed files with 123 additions and 96 deletions
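Note (not part of the commit): a minimal client-side sketch of the new LLM-ST-Errors opt-out. The host, port, and endpoint path below are placeholders; the header must be the literal string 'true' to match the server-side check request.headers.get('LLM-ST-Errors', False) == 'true'.

import requests

resp = requests.post(
    'http://localhost:5000/api/v1/generate',  # placeholder URL -- adjust to your deployment
    json={'prompt': 'Hello'},
    headers={'LLM-ST-Errors': 'true'},  # ask for plain errors instead of SillyTavern-formatted text
)
if resp.status_code != 200:
    # With the header set, errors come back as plain text and status codes
    # rather than being injected into the generated text.
    print(resp.status_code, resp.text)
else:
    print(resp.json())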

View File

@@ -1,28 +1,28 @@
 import pymysql
-from dbutils.pooled_db import PooledDB


 class DatabaseConnection:
-    db_pool: PooledDB = None
+    host: str = None
+    username: str = None
+    password: str = None
+    database: str = None

     def init_db(self, host, username, password, database):
-        self.db_pool = PooledDB(
-            creator=pymysql,
-            maxconnections=10,
-            host=host,
-            user=username,
-            password=password,
-            database=database,
+        self.host = host
+        self.username = username
+        self.password = password
+        self.database = database
+
+    def cursor(self):
+        db = pymysql.connect(
+            host=self.host,
+            user=self.username,
+            password=self.password,
+            database=self.database,
             charset='utf8mb4',
             autocommit=True,
         )
+        return db.cursor()
-        # Test it.
-        conn = self.db_pool.connection()
-        del conn
-
-    def connection(self):
-        return self.db_pool.connection()


-db_pool = DatabaseConnection()
+database = DatabaseConnection()
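Note (illustrative, not part of the diff): with pooling gone, every database.cursor() call opens a fresh pymysql connection with autocommit=True, which is why callers below drop conn.commit() and only close the cursor. A sketch of the resulting caller pattern, assuming init_db() has already run:

from llm_server.database.conn import database

cursor = database.cursor()  # opens a new autocommitting connection each call
try:
    cursor.execute('SELECT COUNT(*) FROM prompts')
    print(cursor.fetchone()[0])
finally:
    cursor.close()  # closes the cursor; the underlying connection is left for GC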

View File

@@ -1,9 +1,8 @@
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database


 def create_db():
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     cursor.execute('''
         CREATE TABLE IF NOT EXISTS prompts (
             ip TEXT,
@@ -38,5 +37,4 @@ def create_db():
             disabled BOOLEAN DEFAULT 0
         )
     ''')
-    conn.commit()
     cursor.close()

View File

@@ -4,7 +4,7 @@ import traceback
 import llm_server
 from llm_server import opts
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 from llm_server.llm.vllm import tokenize
 from llm_server.routes.cache import redis
@@ -37,8 +37,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     running_model = redis.get('running_model', str, 'ERROR')
     timestamp = int(time.time())
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("""
            INSERT INTO prompts
@@ -51,8 +50,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
 def is_valid_api_key(api_key):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("SELECT token, uses, max_uses, expire, disabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
@@ -69,12 +67,10 @@ def is_valid_api_key(api_key):
 def is_api_key_moderated(api_key):
     if not api_key:
         return opts.openai_moderation_enabled
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("SELECT openai_moderation_enabled FROM token_auth WHERE token = %s", (api_key,))
         row = cursor.fetchone()
-        print(bool(row[0]))
         if row is not None:
             return bool(row[0])
         return opts.openai_moderation_enabled
@@ -83,8 +79,7 @@ def is_api_key_moderated(api_key):
 def get_number_of_rows(table_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT COUNT(*) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
@@ -94,8 +89,7 @@ def get_number_of_rows(table_name):
 def average_column(table_name, column_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
@@ -105,8 +99,7 @@ def average_column(table_name, column_name):
 def average_column_for_model(table_name, column_name, model_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s AND token NOT LIKE 'SYSTEM__%%' OR token IS NULL", (model_name,))
         result = cursor.fetchone()
@@ -121,8 +114,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
     else:
         sql = f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL) ORDER BY id DESC"
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         try:
             cursor.execute(sql, (model_name, backend_name, backend_url,))
@@ -152,8 +144,7 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe
 def sum_column(table_name, column_name):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute(f"SELECT SUM({column_name}) FROM {table_name} WHERE token NOT LIKE 'SYSTEM__%%' OR token IS NULL")
         result = cursor.fetchone()
@@ -165,8 +156,7 @@ def sum_column(table_name, column_name):
 def get_distinct_ips_24h():
     # Get the current time and subtract 24 hours (in seconds)
     past_24_hours = int(time.time()) - 24 * 60 * 60
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s AND (token NOT LIKE 'SYSTEM__%%' OR token IS NULL)", (past_24_hours,))
         result = cursor.fetchone()
@@ -176,8 +166,7 @@ def get_distinct_ips_24h():
 def increment_token_uses(token):
-    conn = db_pool.connection()
-    cursor = conn.cursor()
+    cursor = database.cursor()
     try:
         cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
     finally:

View File

@@ -7,7 +7,7 @@ import traceback
 from typing import Dict, List

 import tiktoken
-from flask import jsonify
+from flask import jsonify, make_response

 import llm_server
 from llm_server import opts
@@ -36,7 +36,7 @@ def build_openai_response(prompt, response, model=None):
     response_tokens = llm_server.llm.get_token_count(response)
     running_model = redis.get('running_model', str, 'ERROR')
-    return jsonify({
+    response = make_response(jsonify({
         "id": f"chatcmpl-{generate_oai_string(30)}",
         "object": "chat.completion",
         "created": int(time.time()),
@@ -55,7 +55,12 @@ def build_openai_response(prompt, response, model=None):
             "completion_tokens": response_tokens,
             "total_tokens": prompt_tokens + response_tokens
         }
-    })
+    }), 200)
+    stats = redis.get('proxy_stats', dict)
+    if stats:
+        response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+    return response


 def generate_oai_string(length=24):
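Note (illustrative, not part of the diff): the OpenAI-compatible responses now carry an x-ratelimit-reset-requests header whose value is the queue's estimated wait in seconds, read from the proxy_stats entry in Redis. A hypothetical client reading it (URL and payload are placeholders):

import requests

r = requests.post(
    'http://localhost:5000/api/openai/v1/chat/completions',  # placeholder path
    json={'model': 'some-model', 'messages': [{'role': 'user', 'content': 'hi'}]},
)
wait = r.headers.get('x-ratelimit-reset-requests')
if wait is not None:
    print(f'Estimated queue wait: {wait} seconds')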

View File

@@ -16,5 +16,4 @@ def tokenize(prompt: str) -> int:
         return j['length']
     except:
         traceback.print_exc()
-        print(prompt)
         return len(tokenizer.encode(prompt)) + 10

View File

@@ -83,7 +83,7 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
             return True, data
         elif isinstance(data, bytes):
             s = data.decode('utf-8')
-            return json.loads(s)
+            return False, json.loads(s)
     except Exception as e:
         return False, e
     try:

View File

@@ -28,11 +28,16 @@ class OobaRequestHandler(RequestHandler):
         return backend_response

     def handle_ratelimited(self):
-        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-        log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
-        return jsonify({
-            'results': [{'text': backend_response}]
-        }), 429
+        msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
+        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
+        if disable_st_error_formatting:
+            return msg, 429
+        else:
+            backend_response = format_sillytavern_err(msg, 'error')
+            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
+            return jsonify({
+                'results': [{'text': backend_response}]
+            }), 429

     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         return jsonify({

View File

@@ -10,7 +10,7 @@ from ..cache import redis
 from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
-from ...llm.openai.transform import build_openai_response, generate_oai_string
+from ...llm.openai.transform import build_openai_response, generate_oai_string, transform_messages_to_prompt
 from ... import opts
 from ...database.database import log_prompt
 from ...llm.generator import generator
@@ -21,6 +21,7 @@ from ...llm.vllm import tokenize

 @openai_bp.route('/chat/completions', methods=['POST'])
 def openai_chat_completions():
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
         return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
@@ -42,7 +43,7 @@ def openai_chat_completions():
             # TODO: simulate OAI here
             raise Exception('TODO: simulate OAI here')
         else:
-            handler.prompt = handler.transform_messages_to_prompt()
+            handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
             msg_to_backend = {
                 **handler.parameters,
                 'prompt': handler.prompt,
@@ -112,4 +113,7 @@ def openai_chat_completions():
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!', f'{e.__class__.__name__}: {e}')
         traceback.print_exc()
-        return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500
+        if disable_st_error_formatting:
+            return '500', 500
+        else:
+            return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500

View File

@@ -1,8 +1,7 @@
 import time
-import time
 import traceback

-from flask import jsonify, request
+from flask import jsonify, make_response, request

 from . import openai_bp
 from ..cache import redis
@@ -11,7 +10,7 @@ from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...llm import get_token_count
-from ...llm.openai.transform import generate_oai_string
+from ...llm.openai.transform import build_openai_response, generate_oai_string

 # TODO: add rate-limit headers?
@@ -19,7 +18,6 @@ from ...llm.openai.transform import generate_oai_string
 @openai_bp.route('/completions', methods=['POST'])
 def openai_completions():
     disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
@@ -35,7 +33,7 @@ def openai_completions():
         response_tokens = get_token_count(output)
         running_model = redis.get('running_model', str, 'ERROR')
-        return jsonify({
+        response = make_response(jsonify({
            "id": f"cmpl-{generate_oai_string(30)}",
            "object": "text_completion",
            "created": int(time.time()),
@@ -53,8 +51,16 @@ def openai_completions():
                "completion_tokens": response_tokens,
                "total_tokens": prompt_tokens + response_tokens
            }
-        })
+        }), 200)
+        stats = redis.get('proxy_stats', dict)
+        if stats:
+            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+        return response
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!')
         print(traceback.format_exc())
-        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
+        if disable_st_error_formatting:
+            return '500', 500
+        else:
+            return build_openai_response('', format_sillytavern_err(f'Server encountered exception.', 'error')), 500

View File

@@ -71,9 +71,14 @@ class OpenAIRequestHandler(RequestHandler):
         return backend_response, backend_response_status_code

     def handle_ratelimited(self):
-        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
-        log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
-        return build_openai_response(self.prompt, backend_response), 429
+        disable_st_error_formatting = self.request.headers.get('LLM-ST-Errors', False) == 'true'
+        if disable_st_error_formatting:
+            # TODO: format this like OpenAI does
+            return '429', 429
+        else:
+            backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
+            log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
+            return build_openai_response(self.prompt, backend_response), 429

     def handle_error(self, msg: str) -> Tuple[flask.Response, int]:
         print(msg)

View File

@@ -5,7 +5,7 @@ import flask
 from flask import Response, request

 from llm_server import opts
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 from llm_server.database.database import log_prompt
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
@@ -65,8 +65,7 @@ class RequestHandler:
         priority = DEFAULT_PRIORITY
         simultaneous_ip = opts.simultaneous_requests_per_ip
         if self.token:
-            conn = db_pool.connection()
-            cursor = conn.cursor()
+            cursor = database.cursor()
             try:
                 cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
                 result = cursor.fetchone()
@@ -120,10 +119,15 @@ class RequestHandler:
             else:
                 # Otherwise, just grab the first and only one.
                 combined_error_message = invalid_request_err_msgs[0] + '.'
-            backend_response = format_sillytavern_err(f'Validation Error: {combined_error_message}', 'error')
+            msg = f'Validation Error: {combined_error_message}'
+            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+            if disable_st_error_formatting:
+                backend_response = (Response(msg, 400), 400)
+            else:
+                backend_response = self.handle_error(format_sillytavern_err(msg, 'error'))
             if do_log:
                 log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response, 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
-            return False, self.handle_error(backend_response)
+            return False, backend_response
         return True, (None, 0)

     def generate_response(self, llm_request: dict) -> Tuple[Tuple[bool, flask.Response | None, str | None, float], Tuple[Response, int]]:
@@ -163,9 +167,14 @@ class RequestHandler:
                 error_msg = 'Unknown error.'
             else:
                 error_msg = error_msg.strip('.') + '.'
-            backend_response = format_sillytavern_err(error_msg, 'error')
+            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+            if disable_st_error_formatting:
+                backend_response = (Response(error_msg, 400), 400)
+            else:
+                backend_response = format_sillytavern_err(error_msg, 'error')
             log_prompt(self.client_ip, self.token, prompt, backend_response, None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), self.handle_error(backend_response)
+            return (False, None, None, 0), backend_response

 # ===============================================
@@ -183,9 +192,14 @@ class RequestHandler:
         if return_json_err:
             error_msg = 'The backend did not return valid JSON.'
-            backend_response = format_sillytavern_err(error_msg, 'error')
+            disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
+            if disable_st_error_formatting:
+                # TODO: how to format this
+                backend_response = (Response(error_msg, 400), 400)
+            else:
+                backend_response = self.handle_error(format_sillytavern_err(error_msg, 'error'))
             log_prompt(self.client_ip, self.token, prompt, backend_response, elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
-            return (False, None, None, 0), self.handle_error(backend_response)
+            return (False, None, None, 0), backend_response

 # ===============================================

View File

@@ -10,6 +10,7 @@ from ..ooba_request_handler import OobaRequestHandler

 @bp.route('/generate', methods=['POST'])
 def generate():
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
     request_valid_json, request_json_body = validate_json(request)
     if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
@@ -19,4 +20,7 @@ def generate():
     except Exception as e:
         print(f'EXCEPTION on {request.url}!!!')
         print(traceback.format_exc())
-        return format_sillytavern_err(f'Server encountered exception.', 'error'), 500
+        if disable_st_error_formatting:
+            return '500', 500
+        else:
+            return format_sillytavern_err(f'Server encountered exception.', 'error'), 500

View File

@@ -25,6 +25,7 @@ def stream(ws):
     r_headers = dict(request.headers)
     r_url = request.url
+    disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'

     message_num = 0
     while ws.connected:
@@ -134,22 +135,23 @@ def stream(ws):
                     thread.start()
                     thread.join()
                 except:
-                    generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
-                    traceback.print_exc()
-                    ws.send(json.dumps({
-                        'event': 'text_stream',
-                        'message_num': message_num,
-                        'text': generated_text
-                    }))
+                    if not disable_st_error_formatting:
+                        generated_text = generated_text + '\n\n' + format_sillytavern_err('Encountered error while streaming.', 'error')
+                        traceback.print_exc()
+                        ws.send(json.dumps({
+                            'event': 'text_stream',
+                            'message_num': message_num,
+                            'text': generated_text
+                        }))

                 def background_task_exception():
                     generated_tokens = tokenize(generated_text)
                     log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, None, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)

                 # TODO: use async/await instead of threads
                 thread = threading.Thread(target=background_task_exception)
                 thread.start()
                 thread.join()

             try:
                 ws.send(json.dumps({
                     'event': 'stream_end',

View File

@@ -16,7 +16,7 @@ import simplejson as json
 from flask import Flask, jsonify, render_template, request

 import llm_server
-from llm_server.database.conn import db_pool
+from llm_server.database.conn import database
 from llm_server.database.create import create_db
 from llm_server.database.database import get_number_of_rows
 from llm_server.llm import get_token_count
@@ -25,13 +25,9 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.stream import init_socketio

-# TODO: make sure prompts are logged even when the user cancels generation
 # TODO: add some sort of loadbalancer to send requests to a group of backends
-# TODO: use the current estimated wait time for ratelimit headers on openai
-# TODO: accept a header to specify if the openai endpoint should return sillytavern-formatted errors
 # TODO: allow setting concurrent gens per-backend
 # TODO: use first backend as default backend
-# TODO: allow disabling OpenAI moderation endpoint per-token
 # TODO: allow setting specific simoltaneous IPs allowed per token
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
@@ -83,7 +79,7 @@ if not success:
 if config['database_path'].startswith('./'):
     config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))

-db_pool.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
+database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()

 if config['mode'] not in ['oobabooga', 'vllm']: