fix invalid param error, add manual model name

Cyberes 2023-09-12 10:30:45 -06:00
parent 5dd95875dd
commit 6152b1bb66
11 changed files with 34 additions and 25 deletions

View File

@@ -15,7 +15,8 @@ config_default_vars = {
     'show_total_output_tokens': True,
     'simultaneous_requests_per_ip': 3,
     'show_backend_info': True,
-    'max_new_tokens': 500
+    'max_new_tokens': 500,
+    'manual_model_name': False
 }

 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
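Note: `manual_model_name` defaults to `False` (disabled), and like the other defaults it only applies when the operator's config omits the key. A minimal sketch of that backfill, assuming a plain dict merge (the `load_config` helper here is hypothetical, not part of this repo):

    def load_config(user_config: dict) -> dict:
        # Defaults first; user-supplied values override them.
        return {**config_default_vars, **user_config}

    config = load_config({'max_new_tokens': 300})
    assert config['manual_model_name'] is False  # untouched default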

View File

@@ -11,7 +11,7 @@ class LLMBackend:
     # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
     #     raise NotImplementedError

-    def get_parameters(self, parameters) -> Union[dict, None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         """
         Validate and return the parameters for this backend.
         Lets you set defaults for specific backends.
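Note: the base-class annotation now matches what callers destructure later in this commit: a `(parameters, error_message)` pair instead of a bare dict. A minimal sketch of the intended contract (simplified, not the full class):

    from typing import Optional, Tuple

    class LLMBackend:
        def get_parameters(self, parameters: dict) -> Tuple[Optional[dict], Optional[str]]:
            # Return (validated_params, None) on success,
            # or (None, 'human-readable reason') on failure.
            raise NotImplementedError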

View File

@@ -78,7 +78,7 @@ class VLLMBackend(LLMBackend):
     #     except Exception as e:
     #         return False, e

-    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         default_params = SamplingParams()
         try:
             sampling_params = SamplingParams(
@@ -91,8 +91,7 @@ class VLLMBackend(LLMBackend):
                 max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
             )
         except ValueError as e:
-            print(e)
-            return None, e
+            return None, str(e).strip('.')
         return vars(sampling_params), None

     # def transform_sampling_params(params: SamplingParams):
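Note: returning `str(e).strip('.')` instead of the raw exception (and dropping the `print`) gives callers a clean message fragment they can embed in their own error text. A usage sketch, assuming vLLM's `SamplingParams` rejects a negative temperature with a `ValueError` (the exact message is vLLM's, not guaranteed here):

    backend = VLLMBackend()
    params, error_msg = backend.get_parameters({'temperature': -1.0})
    if error_msg:
        # error_msg is already stripped of its trailing period
        print(f'Validation Error: {error_msg}.')
    else:
        print(params)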

View File

@@ -22,3 +22,4 @@ show_total_output_tokens = True
 netdata_root = None
 simultaneous_requests_per_ip = 3
 show_backend_info = True
+manual_model_name = None

View File

@@ -56,10 +56,6 @@ class OobaRequestHandler:
         else:
             return self.request.remote_addr

-    # def get_parameters(self):
-    #     # TODO: make this a LLMBackend method
-    #     return self.backend.get_parameters()
-
     def get_priority(self):
         if self.token:
             conn = sqlite3.connect(opts.database_path)
@@ -85,22 +81,22 @@ class OobaRequestHandler:
         self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

     def handle_request(self):
-        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
         request_valid_json, self.request_json_body = validate_json(self.request.data)
         if not request_valid_json:
             return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

         self.get_parameters()
+        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()

-        request_valid, invalid_request_err_msg = self.validate_request()
-        if not self.parameters:
-            params_valid = False
-        else:
-            params_valid = True
+        params_valid = False
+        request_valid = False
+        invalid_request_err_msg = None
+        if self.parameters:
+            params_valid = True
+            request_valid, invalid_request_err_msg = self.validate_request()

         if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
             combined_error_message = ', '.join(error_messages)
             err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
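Note: the `and msg` filter is what fixes the `Validation Error: None.` style output: when parameter parsing fails, `validate_request()` never runs, so `invalid_request_err_msg` stays `None`, and the filter keeps that `None` out of the joined string. A quick illustration (the messages are made up):

    checks = [(False, None), (False, 'temperature must be non-negative')]
    print(', '.join(msg for valid, msg in checks if not valid and msg))
    # -> temperature must be non-negative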

View File

@@ -0,0 +1,3 @@
+def handle_server_error(e):
+    print(e)
+    return {'error': True}, 500
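Note: this new `server_error.py` module centralizes the opaque 500 response so the main Flask app and the v1 blueprint (both wired up below) return the same body. Flask also accepts such a function directly via `register_error_handler`; a minimal wiring sketch (the `app` here is hypothetical):

    from flask import Flask
    from llm_server.routes.server_error import handle_server_error

    app = Flask(__name__)
    app.register_error_handler(500, handle_server_error)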

View File

@@ -1,6 +1,7 @@
 from flask import Blueprint, request

 from ..helpers.http import require_api_key
+from ..server_error import handle_server_error
 from ... import opts

 bp = Blueprint('v1', __name__)
@@ -18,4 +19,9 @@ def before_request():
     return response


+@bp.errorhandler(500)
+def handle_error(e):
+    return handle_server_error(e)
+
+
 from . import generate, info, proxy, generate_stream

View File

@@ -9,7 +9,7 @@ from ... import opts
 @bp.route('/generate', methods=['POST'])
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
-    if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
+    if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
         handler = OobaRequestHandler(request)
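Note: with the `messages` fallback removed, this endpoint only accepts the Oobabooga-style `prompt` field. A client-side sketch (the host, port, and URL prefix are assumptions, not taken from this repo):

    import requests

    resp = requests.post(
        'http://localhost:5000/api/v1/generate',  # hypothetical address
        json={'prompt': 'Once upon a time', 'max_new_tokens': 100},
    )
    print(resp.json())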

View File

@@ -94,7 +94,7 @@ def generate_stats():
         'gatekeeper': 'none' if opts.auth_required is False else 'token',
         'context_size': opts.context_size,
         'concurrent': opts.concurrent_gens,
-        'model': model_name,
+        'model': opts.manual_model_name if opts.manual_model_name else model_name,
         'mode': opts.mode,
         'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
     },

View File

@@ -4,6 +4,7 @@ from flask import jsonify, request

 from . import bp
 from ..cache import cache
+from ... import opts
 from ...llm.info import get_running_model
@@ -27,8 +28,8 @@ def get_model():
     if cached_response:
         return cached_response

-    model, error = get_running_model()
-    if not model:
+    model_name, error = get_running_model()
+    if not model_name:
         response = jsonify({
             'code': 502,
             'msg': 'failed to reach backend',
@@ -36,7 +37,7 @@ def get_model():
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         response = jsonify({
-            'result': model,
+            'result': opts.manual_model_name if opts.manual_model_name else model_name,
             'timestamp': int(time.time())
         }), 200
     cache.set(cache_key, response, timeout=60)
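Note: both this endpoint and the stats block above now prefer `opts.manual_model_name` whenever it is truthy, so the falsy defaults (`False` in `config_default_vars`, `None` in opts) leave the backend's real model name in place. The selection reduces to (model names made up):

    def displayed_model_name(manual_model_name, backend_model_name):
        # Any falsy override (None/False/'') falls through to the backend's answer.
        return manual_model_name if manual_model_name else backend_model_name

    assert displayed_model_name(None, 'llama-2-13b') == 'llama-2-13b'
    assert displayed_model_name('MyProxyModel', 'llama-2-13b') == 'MyProxyModel'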

View File

@@ -6,6 +6,8 @@ from threading import Thread

 from flask import Flask, jsonify, render_template, request

+from llm_server.routes.server_error import handle_server_error
+
 try:
     import vllm
 except ModuleNotFoundError as e:
@@ -65,6 +67,7 @@ opts.netdata_root = config['netdata_root']
 opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
 opts.show_backend_info = config['show_backend_info']
 opts.max_new_tokens = config['max_new_tokens']
+opts.manual_model_name = config['manual_model_name']
 opts.verify_ssl = config['verify_ssl']

 if not opts.verify_ssl:
@@ -145,7 +148,7 @@ def home():
         llm_middleware_name=config['llm_middleware_name'],
         analytics_tracking_code=analytics_tracking_code,
         info_html=info_html,
-        current_model=running_model,
+        current_model=opts.manual_model_name if opts.manual_model_name else running_model,
         client_api=stats['endpoints']['blocking'],
         ws_client_api=stats['endpoints']['streaming'],
         estimated_wait=estimated_wait_sec,
@@ -169,8 +172,7 @@ def fallback(first=None, rest=None):

 @app.errorhandler(500)
 def server_error(e):
-    print(e)
-    return {'error': True}, 500
+    return handle_server_error(e)

 if __name__ == "__main__":