fix invalid param error, add manual model name
This commit is contained in: parent 5dd95875dd · commit 6152b1bb66
@@ -15,7 +15,8 @@ config_default_vars = {
     'show_total_output_tokens': True,
     'simultaneous_requests_per_ip': 3,
     'show_backend_info': True,
-    'max_new_tokens': 500
+    'max_new_tokens': 500,
+    'manual_model_name': False
 }

 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
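The new manual_model_name key defaults to False so existing configs keep working. A minimal sketch of how a defaults map like this is usually merged with user-supplied config (the project's actual loader may differ):

    # Hypothetical merge of user config over config_default_vars (the
    # dict in the hunk above); the project's real loader may differ.
    def apply_defaults(user_config: dict, defaults: dict) -> dict:
        merged = dict(defaults)     # start from the defaults
        merged.update(user_config)  # user-supplied values win
        return merged

    config = apply_defaults({'max_new_tokens': 800}, config_default_vars)
    # config['manual_model_name'] stays False unless the user sets it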
@@ -11,7 +11,7 @@ class LLMBackend:
     # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
     #     raise NotImplementedError

-    def get_parameters(self, parameters) -> Union[dict, None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         """
         Validate and return the parameters for this backend.
         Lets you set defaults for specific backends.
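Callers already unpack two values from get_parameters(), so the base-class annotation now matches: a (parameters, error_message) pair rather than a bare dict-or-None. A self-contained sketch of the contract, with an illustrative subclass that is not from the repo:

    from typing import Tuple, Union

    class LLMBackend:  # stub mirroring the base class in this diff
        def get_parameters(self, parameters) -> Tuple[Union[dict, None], Union[str, None]]:
            raise NotImplementedError

    class EchoBackend(LLMBackend):  # illustrative only
        def get_parameters(self, parameters):
            # On bad input: (None, message). On success: (params, None).
            if parameters.get('temperature', 1.0) < 0:
                return None, 'temperature must be non-negative'
            return dict(parameters), None

    params, err = EchoBackend().get_parameters({'temperature': -1.0})
    assert params is None and err == 'temperature must be non-negative'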
@@ -78,7 +78,7 @@ class VLLMBackend(LLMBackend):
     # except Exception as e:
     #     return False, e

-    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         default_params = SamplingParams()
         try:
             sampling_params = SamplingParams(
@@ -91,8 +91,7 @@ class VLLMBackend(LLMBackend):
                 max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
             )
         except ValueError as e:
-            print(e)
-            return None, e
+            return None, str(e).strip('.')
         return vars(sampling_params), None

     # def transform_sampling_params(params: SamplingParams):
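This is the core of the invalid-param fix: SamplingParams validates its arguments eagerly and raises ValueError, and the old code returned the exception object itself, which the ', '.join in the request handler below cannot concatenate. Returning str(e) with the trailing period stripped yields a message that drops cleanly into 'Validation Error: ...'. A simulation of that branch that needs no vLLM install:

    # make_sampling_params stands in for vLLM's SamplingParams, which
    # raises ValueError with sentence-like messages on bad input.
    def make_sampling_params(temperature: float) -> dict:
        if temperature < 0:
            raise ValueError(f'temperature must be non-negative, got {temperature}.')
        return {'temperature': temperature}

    try:
        make_sampling_params(-1.0)
    except ValueError as e:
        # strip('.') drops the trailing period so the caller can add its
        # own punctuation when joining messages.
        print(str(e).strip('.'))  # temperature must be non-negative, got -1.0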
@@ -22,3 +22,4 @@ show_total_output_tokens = True
 netdata_root = None
 simultaneous_requests_per_ip = 3
 show_backend_info = True
+manual_model_name = None
@@ -56,10 +56,6 @@ class OobaRequestHandler:
         else:
             return self.request.remote_addr

-    # def get_parameters(self):
-    #     # TODO: make this a LLMBackend method
-    #     return self.backend.get_parameters()
-
     def get_priority(self):
         if self.token:
             conn = sqlite3.connect(opts.database_path)
@@ -85,22 +81,22 @@ class OobaRequestHandler:
         self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

     def handle_request(self):
+        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
+
         request_valid_json, self.request_json_body = validate_json(self.request.data)
         if not request_valid_json:
             return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

         self.get_parameters()

-        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
-
-        request_valid, invalid_request_err_msg = self.validate_request()
-        if not self.parameters:
-            params_valid = False
-        else:
-            params_valid = True
+        params_valid = False
+        request_valid = False
+        invalid_request_err_msg = None
+        if self.parameters:
+            params_valid = True
+            request_valid, invalid_request_err_msg = self.validate_request()

         if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
             combined_error_message = ', '.join(error_messages)
             err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
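handle_request now records the prompter before any validation, initializes all three flags pessimistically, and only calls validate_request() once parameters have parsed, so a bad-parameter request can no longer reference an unset invalid_request_err_msg. The new 'and msg' guard keeps None placeholders out of the join; as a standalone sketch:

    # Standalone version of the error-combining step above.
    def combine_errors(request_valid, request_err, params_valid, params_err):
        # Keep a message only when its check failed AND it carries text;
        # the 'and msg' guard added in this commit skips None placeholders.
        error_messages = [msg for valid, msg in [(request_valid, request_err),
                                                 (params_valid, params_err)]
                          if not valid and msg]
        return ', '.join(error_messages)

    # Parameters failed, so the request check never ran and its message is None:
    print(combine_errors(False, None, False, 'temperature must be non-negative'))
    # -> temperature must be non-negative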
@@ -0,0 +1,3 @@
+def handle_server_error(e):
+    print(e)
+    return {'error': True}, 500
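The new module centralizes the 500 response that server.py previously inlined, so the app-level handler and the blueprint-level handler added below can share one implementation. A minimal sketch of the wiring on a bare Flask app:

    # Minimal wiring sketch; the real project registers the handler on
    # both the Flask app and the v1 blueprint, as the next hunks show.
    from flask import Flask

    def handle_server_error(e):
        print(e)
        return {'error': True}, 500

    app = Flask(__name__)

    @app.errorhandler(500)
    def server_error(e):
        return handle_server_error(e)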
@@ -1,6 +1,7 @@
 from flask import Blueprint, request

 from ..helpers.http import require_api_key
+from ..server_error import handle_server_error
 from ... import opts

 bp = Blueprint('v1', __name__)

@@ -18,4 +19,9 @@ def before_request():
     return response


+@bp.errorhandler(500)
+def handle_error(e):
+    return handle_server_error(e)
+
+
 from . import generate, info, proxy, generate_stream
@@ -9,7 +9,7 @@ from ... import opts
 @bp.route('/generate', methods=['POST'])
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
-    if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
+    if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
         handler = OobaRequestHandler(request)
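The /generate route now insists on a top-level 'prompt'; a body carrying only 'messages' gets the same 400 as malformed JSON. A hypothetical client call (the mount prefix, host, and port are assumptions):

    import requests

    # Accepted: the body has a 'prompt' key.
    r = requests.post('http://127.0.0.1:5000/api/v1/generate',
                      json={'prompt': 'Say hello.', 'max_new_tokens': 50})

    # Rejected as of this commit: 'messages' alone no longer passes.
    r = requests.post('http://127.0.0.1:5000/api/v1/generate',
                      json={'messages': [{'role': 'user', 'content': 'hi'}]})
    print(r.status_code, r.json())  # 400 {'code': 400, 'msg': 'Invalid JSON'}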
@@ -94,7 +94,7 @@ def generate_stats():
             'gatekeeper': 'none' if opts.auth_required is False else 'token',
             'context_size': opts.context_size,
             'concurrent': opts.concurrent_gens,
-            'model': model_name,
+            'model': opts.manual_model_name if opts.manual_model_name else model_name,
             'mode': opts.mode,
             'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
         },
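The same override ternary appears here, in the model info endpoint, and in home(): whenever manual_model_name is truthy it masks the backend-reported name. Standalone, with a stand-in namespace:

    from types import SimpleNamespace

    # opts here is a stand-in namespace; the model names are made up.
    opts = SimpleNamespace(manual_model_name=None)
    model_name = 'backend-reported-model'

    print(opts.manual_model_name if opts.manual_model_name else model_name)
    # -> backend-reported-model

    opts.manual_model_name = 'my-public-model-name'
    print(opts.manual_model_name if opts.manual_model_name else model_name)
    # -> my-public-model-name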
@@ -4,6 +4,7 @@ from flask import jsonify, request

 from . import bp
 from ..cache import cache
+from ... import opts
 from ...llm.info import get_running_model


@@ -27,8 +28,8 @@ def get_model():
     if cached_response:
         return cached_response

-    model, error = get_running_model()
-    if not model:
+    model_name, error = get_running_model()
+    if not model_name:
         response = jsonify({
             'code': 502,
             'msg': 'failed to reach backend',

@@ -36,7 +37,7 @@ def get_model():
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         response = jsonify({
-            'result': model,
+            'result': opts.manual_model_name if opts.manual_model_name else model_name,
             'timestamp': int(time.time())
         }), 200
     cache.set(cache_key, response, timeout=60)
@@ -6,6 +6,8 @@ from threading import Thread

 from flask import Flask, jsonify, render_template, request

+from llm_server.routes.server_error import handle_server_error
+
 try:
     import vllm
 except ModuleNotFoundError as e:
@@ -65,6 +67,7 @@ opts.netdata_root = config['netdata_root']
 opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
 opts.show_backend_info = config['show_backend_info']
 opts.max_new_tokens = config['max_new_tokens']
+opts.manual_model_name = config['manual_model_name']

 opts.verify_ssl = config['verify_ssl']
 if not opts.verify_ssl:
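Note the defaults differ in kind: False in config_default_vars, None in the opts module. Both are falsy, so every display site behaves identically until a real name is configured:

    # Both defaults short-circuit the override the same way:
    for default in (False, None):
        assert (default if default else 'real-name') == 'real-name'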
@@ -145,7 +148,7 @@ def home():
         llm_middleware_name=config['llm_middleware_name'],
         analytics_tracking_code=analytics_tracking_code,
         info_html=info_html,
-        current_model=running_model,
+        current_model=opts.manual_model_name if opts.manual_model_name else running_model,
         client_api=stats['endpoints']['blocking'],
         ws_client_api=stats['endpoints']['streaming'],
         estimated_wait=estimated_wait_sec,
@@ -169,8 +172,7 @@ def fallback(first=None, rest=None):

 @app.errorhandler(500)
 def server_error(e):
-    print(e)
-    return {'error': True}, 500
+    return handle_server_error(e)


 if __name__ == "__main__":