fix invalid param error, add manual model name

This commit is contained in:
Cyberes 2023-09-12 10:30:45 -06:00
parent 5dd95875dd
commit 6152b1bb66
11 changed files with 34 additions and 25 deletions

View File

@ -15,7 +15,8 @@ config_default_vars = {
'show_total_output_tokens': True,
'simultaneous_requests_per_ip': 3,
'show_backend_info': True,
'max_new_tokens': 500
'max_new_tokens': 500,
'manual_model_name': False
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -11,7 +11,7 @@ class LLMBackend:
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# raise NotImplementedError
def get_parameters(self, parameters) -> Union[dict, None]:
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
"""
Validate and return the parameters for this backend.
Lets you set defaults for specific backends.

View File

@ -78,7 +78,7 @@ class VLLMBackend(LLMBackend):
# except Exception as e:
# return False, e
def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
default_params = SamplingParams()
try:
sampling_params = SamplingParams(
@ -91,8 +91,7 @@ class VLLMBackend(LLMBackend):
max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
)
except ValueError as e:
print(e)
return None, e
return None, str(e).strip('.')
return vars(sampling_params), None
# def transform_sampling_params(params: SamplingParams):

View File

@ -22,3 +22,4 @@ show_total_output_tokens = True
netdata_root = None
simultaneous_requests_per_ip = 3
show_backend_info = True
manual_model_name = None

View File

@ -56,10 +56,6 @@ class OobaRequestHandler:
else:
return self.request.remote_addr
# def get_parameters(self):
# # TODO: make this a LLMBackend method
# return self.backend.get_parameters()
def get_priority(self):
if self.token:
conn = sqlite3.connect(opts.database_path)
@ -85,22 +81,22 @@ class OobaRequestHandler:
self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
def handle_request(self):
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
request_valid_json, self.request_json_body = validate_json(self.request.data)
if not request_valid_json:
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
self.get_parameters()
SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
request_valid, invalid_request_err_msg = self.validate_request()
if not self.parameters:
params_valid = False
else:
params_valid = False
request_valid = False
invalid_request_err_msg = None
if self.parameters:
params_valid = True
request_valid, invalid_request_err_msg = self.validate_request()
if not request_valid or not params_valid:
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
combined_error_message = ', '.join(error_messages)
err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)

View File

@ -0,0 +1,3 @@
def handle_server_error(e):
print(e)
return {'error': True}, 500

View File

@ -1,6 +1,7 @@
from flask import Blueprint, request
from ..helpers.http import require_api_key
from ..server_error import handle_server_error
from ... import opts
bp = Blueprint('v1', __name__)
@ -18,4 +19,9 @@ def before_request():
return response
@bp.errorhandler(500)
def handle_error(e):
return handle_server_error(e)
from . import generate, info, proxy, generate_stream

View File

@ -9,7 +9,7 @@ from ... import opts
@bp.route('/generate', methods=['POST'])
def generate():
request_valid_json, request_json_body = validate_json(request.data)
if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
handler = OobaRequestHandler(request)

View File

@ -94,7 +94,7 @@ def generate_stats():
'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'concurrent': opts.concurrent_gens,
'model': model_name,
'model': opts.manual_model_name if opts.manual_model_name else model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
},

View File

@ -4,6 +4,7 @@ from flask import jsonify, request
from . import bp
from ..cache import cache
from ... import opts
from ...llm.info import get_running_model
@ -27,8 +28,8 @@ def get_model():
if cached_response:
return cached_response
model, error = get_running_model()
if not model:
model_name, error = get_running_model()
if not model_name:
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
@ -36,7 +37,7 @@ def get_model():
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
response = jsonify({
'result': model,
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'timestamp': int(time.time())
}), 200
cache.set(cache_key, response, timeout=60)

View File

@ -6,6 +6,8 @@ from threading import Thread
from flask import Flask, jsonify, render_template, request
from llm_server.routes.server_error import handle_server_error
try:
import vllm
except ModuleNotFoundError as e:
@ -65,6 +67,7 @@ opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
opts.show_backend_info = config['show_backend_info']
opts.max_new_tokens = config['max_new_tokens']
opts.manual_model_name = config['manual_model_name']
opts.verify_ssl = config['verify_ssl']
if not opts.verify_ssl:
@ -145,7 +148,7 @@ def home():
llm_middleware_name=config['llm_middleware_name'],
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=running_model,
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
client_api=stats['endpoints']['blocking'],
ws_client_api=stats['endpoints']['streaming'],
estimated_wait=estimated_wait_sec,
@ -169,8 +172,7 @@ def fallback(first=None, rest=None):
@app.errorhandler(500)
def server_error(e):
print(e)
return {'error': True}, 500
return handle_server_error(e)
if __name__ == "__main__":