Display error messages in SillyTavern

This commit is contained in:
Cyberes 2023-08-22 20:28:41 -06:00
parent 0d32db2dbd
commit a9b7a7a2c7
5 changed files with 80 additions and 19 deletions

View File

@ -3,3 +3,17 @@ from pathlib import Path
def resolve_path(*p: str): def resolve_path(*p: str):
return Path(*p).expanduser().resolve().absolute() return Path(*p).expanduser().resolve().absolute()
def safe_list_get(l, idx, default):
    """Fetch l[idx], falling back to *default* when the index is out of range.

    Adapted from https://stackoverflow.com/a/5125636
    """
    # EAFP: attempt the lookup and only catch IndexError, so other
    # errors (e.g. TypeError on a non-indexable object) still surface.
    try:
        value = l[idx]
    except IndexError:
        value = default
    return value

View File

@ -9,3 +9,4 @@ database_path = './proxy-server.db'
auth_required = False auth_required = False
log_prompts = False log_prompts = False
frontend_api_client = '' frontend_api_client = ''
http_host = None

View File

@ -0,0 +1,9 @@
from llm_server import opts
def format_sillytavern_err(msg: str, level: str = 'info'):
    """Wrap *msg* in a Markdown code fence so SillyTavern renders it to the
    user as a message from the middleware, tagged with the serving host
    (opts.http_host) and an upper-cased severity level.
    """
    banner = f'=== MESSAGE FROM LLM MIDDLEWARE AT {opts.http_host} ==='
    return '\n'.join(('```', banner, f'-> {level.upper()} <-', msg, '```'))

View File

@ -1,6 +1,7 @@
from flask import Blueprint, request from flask import Blueprint, request
from ..helpers.http import require_api_key from ..helpers.http import require_api_key
from ... import opts
bp = Blueprint('v1', __name__) bp = Blueprint('v1', __name__)
@ -9,6 +10,8 @@ bp = Blueprint('v1', __name__)
@bp.before_request @bp.before_request
def before_request(): def before_request():
if not opts.http_host:
opts.http_host = request.headers.get("Host")
if request.endpoint != 'v1.get_stats': if request.endpoint != 'v1.get_stats':
response = require_api_key() response = require_api_key()
if response is not None: if response is not None:

View File

@ -1,10 +1,12 @@
from flask import jsonify, request from flask import jsonify, request
from . import bp
from llm_server.routes.stats import concurrent_semaphore, proompts from llm_server.routes.stats import concurrent_semaphore, proompts
from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import cache_control, validate_json from ..helpers.http import cache_control, validate_json
from ... import opts from ... import opts
from ...database import log_prompt from ...database import log_prompt
from ...helpers import safe_list_get
def generator(request_json_body): def generator(request_json_body):
@ -26,40 +28,72 @@ def generate():
return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400 return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
with concurrent_semaphore: with concurrent_semaphore:
if request.headers.get('cf-connecting-ip'):
client_ip = request.headers.get('cf-connecting-ip')
elif request.headers.get('x-forwarded-for'):
client_ip = request.headers.get('x-forwarded-for')
else:
client_ip = request.remote_addr
token = request.headers.get('X-Api-Key')
parameters = request_json_body.copy()
del parameters['prompt']
success, response, error_msg = generator(request_json_body) success, response, error_msg = generator(request_json_body)
if not success: if not success:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'Failed to reach the backend: {error_msg}', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({ return jsonify({
'code': 500, 'code': 500,
'error': 'failed to reach backend' 'error': 'failed to reach backend',
}), 500 **response_json_body
}), 200
response_valid_json, response_json_body = validate_json(response) response_valid_json, response_json_body = validate_json(response)
if response_valid_json: if response_valid_json:
proompts.increment() proompts.increment()
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'Backend returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.token_limit}.', 'error')
response_json_body['results'][0]['text'] = backend_response
else:
raise Exception
# request.headers = {"host": "proxy.chub-archive.evulid.cc", "x-forwarded-proto": "https", "user-agent": "node-fetch/1.0 (+https://github.com/bitinn/node-fetch)", "cf-visitor": {"scheme": "https"}, "cf-ipcountry": "CH", "accept": "*/*", "accept-encoding": "gzip", log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
# "x-forwarded-for": "193.32.127.228", "cf-ray": "7fa72c6a6d5cbba7-FRA", "cf-connecting-ip": "193.32.127.228", "cdn-loop": "cloudflare", "content-type": "application/json", "content-length": "9039"}
if request.headers.get('cf-connecting-ip'): print(response_json_body)
client_ip = request.headers.get('cf-connecting-ip')
elif request.headers.get('x-forwarded-for'):
client_ip = request.headers.get('x-forwarded-for')
else:
client_ip = request.remote_addr
parameters = request_json_body.copy()
del parameters['prompt']
token = request.headers.get('X-Api-Key')
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], response_json_body['results'][0]['text'], parameters, dict(request.headers), response.status_code)
return jsonify({ return jsonify({
**response_json_body **response_json_body
}), 200 }), 200
else: else:
if opts.mode == 'oobabooga':
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
response_json_body = {
'results': [
{
'text': backend_response,
}
],
}
else:
raise Exception
log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
return jsonify({ return jsonify({
'code': 500, 'code': 500,
'error': 'failed to reach backend' 'error': 'the backend did not return valid JSON',
}), 500 **response_json_body
}), 200
# @openai_bp.route('/chat/completions', methods=['POST']) # @openai_bp.route('/chat/completions', methods=['POST'])
# def generate_openai(): # def generate_openai():