fix invalid param error, add manual model name
parent 5dd95875dd
commit 6152b1bb66
@@ -15,7 +15,8 @@ config_default_vars = {
     'show_total_output_tokens': True,
     'simultaneous_requests_per_ip': 3,
     'show_backend_info': True,
-    'max_new_tokens': 500
+    'max_new_tokens': 500,
+    'manual_model_name': False
 }
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

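For context, a minimal sketch of how a defaults dict plus required-vars list like this is typically applied at config-load time; the load_config helper and its error handling are assumptions for illustration, not code from this commit:

config_default_vars = {'max_new_tokens': 500, 'manual_model_name': False}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

def load_config(user_config: dict) -> dict:
    # Reject a config that omits any required key.
    missing = [k for k in config_required_vars if k not in user_config]
    if missing:
        raise ValueError(f'missing required config vars: {missing}')
    # User values override the defaults; anything omitted (e.g. the new
    # manual_model_name) falls back to its default.
    return {**config_default_vars, **user_config}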
@@ -11,7 +11,7 @@ class LLMBackend:
     # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
     #     raise NotImplementedError

-    def get_parameters(self, parameters) -> Union[dict, None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         """
         Validate and return the parameters for this backend.
         Lets you set defaults for specific backends.
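The new signature establishes a (parameters, error_message) contract: callers receive either a validated dict or a human-readable string explaining the failure, instead of a bare dict-or-None. A self-contained sketch of a caller consuming it; EchoBackend is a stand-in, not one of the repo's backends:

class EchoBackend:
    # Stand-in backend that accepts anything; real backends validate here.
    def get_parameters(self, parameters):
        return dict(parameters), None

params, error_msg = EchoBackend().get_parameters({'max_new_tokens': 100})
if params is None:
    # error_msg is a plain string, ready to be joined into the API error reply
    print(f'parameter validation failed: {error_msg}')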
@@ -78,7 +78,7 @@ class VLLMBackend(LLMBackend):
     # except Exception as e:
     #     return False, e

-    def get_parameters(self, parameters) -> Tuple[dict | None, Exception | None]:
+    def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
         default_params = SamplingParams()
         try:
             sampling_params = SamplingParams(
@@ -91,8 +91,7 @@ class VLLMBackend(LLMBackend):
                 max_tokens=parameters.get('max_new_tokens', default_params.max_tokens)
             )
         except ValueError as e:
-            print(e)
-            return None, e
+            return None, str(e).strip('.')
         return vars(sampling_params), None

     # def transform_sampling_params(params: SamplingParams):
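Returning str(e).strip('.') instead of the exception object (and dropping the stray print) hands the caller a clean message fragment that joins into a larger error string without a double period. The same pattern in isolation, with a stand-in validator rather than vLLM's actual SamplingParams:

def get_parameters(parameters: dict):
    try:
        if parameters.get('temperature', 1.0) < 0:
            raise ValueError('temperature must be non-negative.')
    except ValueError as e:
        # Strip the trailing period so fragments join cleanly with ', '
        return None, str(e).strip('.')
    return parameters, None

print(get_parameters({'temperature': -1}))  # -> (None, 'temperature must be non-negative')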
@@ -22,3 +22,4 @@ show_total_output_tokens = True
 netdata_root = None
 simultaneous_requests_per_ip = 3
 show_backend_info = True
+manual_model_name = None
@@ -56,10 +56,6 @@ class OobaRequestHandler:
         else:
             return self.request.remote_addr

-    # def get_parameters(self):
-    #     # TODO: make this a LLMBackend method
-    #     return self.backend.get_parameters()
-
     def get_priority(self):
         if self.token:
             conn = sqlite3.connect(opts.database_path)
@@ -85,22 +81,22 @@ class OobaRequestHandler:
         self.parameters, self.parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)

     def handle_request(self):
-        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
-
         request_valid_json, self.request_json_body = validate_json(self.request.data)
         if not request_valid_json:
             return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400

         self.get_parameters()

-        request_valid, invalid_request_err_msg = self.validate_request()
-        if not self.parameters:
-            params_valid = False
-        else:
-            params_valid = True
+        SemaphoreCheckerThread.recent_prompters[self.client_ip] = time.time()
+
+        params_valid = False
+        request_valid = False
+        invalid_request_err_msg = None
+        if self.parameters:
+            params_valid = True
+            request_valid, invalid_request_err_msg = self.validate_request()

         if not request_valid or not params_valid:
-            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid]
+            error_messages = [msg for valid, msg in [(request_valid, invalid_request_err_msg), (params_valid, self.parameters_invalid_msg)] if not valid and msg]
             combined_error_message = ', '.join(error_messages)
             err = format_sillytavern_err(f'Validation Error: {combined_error_message}.', 'error')
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), err, 0, self.parameters, dict(self.request.headers), 0, is_error=True)
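This reordering is the actual "invalid param error" fix: validate_request() previously ran even when get_parameters() had already failed, and the message join could pick up a None entry. Now validation only happens once parameters are known good, and the 'and msg' guard drops empty messages. A self-contained sketch of just the message-combining behavior:

def combine_errors(request_valid, request_msg, params_valid, params_msg):
    # 'and msg' filters out None placeholders so ', '.join() never sees them
    msgs = [msg for valid, msg in [(request_valid, request_msg),
                                   (params_valid, params_msg)] if not valid and msg]
    return ', '.join(msgs)

# With invalid params the request check is skipped, so its message stays None:
print(combine_errors(False, None, False, 'max_new_tokens must be an integer'))
# -> 'max_new_tokens must be an integer'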
@@ -0,0 +1,3 @@
+def handle_server_error(e):
+    print(e)
+    return {'error': True}, 500
@@ -1,6 +1,7 @@
 from flask import Blueprint, request

 from ..helpers.http import require_api_key
+from ..server_error import handle_server_error
 from ... import opts

 bp = Blueprint('v1', __name__)
@@ -18,4 +19,9 @@ def before_request():
     return response


+@bp.errorhandler(500)
+def handle_error(e):
+    return handle_server_error(e)
+
+
 from . import generate, info, proxy, generate_stream
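With the handler factored out into its own module, the v1 blueprint and (further down) the main app both funnel 500s through the same function. A toy reproduction of the wiring, assuming a bare Flask app rather than this repo's layout:

from flask import Flask

app = Flask(__name__)

def handle_server_error(e):
    print(e)
    return {'error': True}, 500

@app.errorhandler(500)
def handle_error(e):
    return handle_server_error(e)

@app.route('/boom')
def boom():
    # Any unhandled exception here surfaces as {"error": true} with HTTP 500
    raise RuntimeError('backend fell over')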
@@ -9,7 +9,7 @@ from ... import opts
 @bp.route('/generate', methods=['POST'])
 def generate():
     request_valid_json, request_json_body = validate_json(request.data)
-    if not request_valid_json or not (request_json_body.get('prompt') or request_json_body.get('messages')):
+    if not request_valid_json or not request_json_body.get('prompt'):
         return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
     else:
         handler = OobaRequestHandler(request)
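The tightened check means /generate now requires a prompt string; a messages array alone no longer passes. Illustrative request bodies (field values made up):

accepted = {'prompt': 'Once upon a time', 'max_new_tokens': 200}
rejected = {'messages': [{'role': 'user', 'content': 'hi'}]}  # now answered with 400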
@@ -94,7 +94,7 @@ def generate_stats():
             'gatekeeper': 'none' if opts.auth_required is False else 'token',
             'context_size': opts.context_size,
             'concurrent': opts.concurrent_gens,
-            'model': model_name,
+            'model': opts.manual_model_name if opts.manual_model_name else model_name,
             'mode': opts.mode,
             'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
         },
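The override is a plain truthiness fallback, repeated in the /model endpoint and the home page below: when manual_model_name is unset (False in the defaults, None in the sample config), the backend-reported name is used. Isolated, with a made-up model name:

manual_model_name = None            # config default: no override
model_name = 'backend-reported-model'
print(manual_model_name if manual_model_name else model_name)
# -> 'backend-reported-model'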
@@ -4,6 +4,7 @@ from flask import jsonify, request

 from . import bp
 from ..cache import cache
+from ... import opts
 from ...llm.info import get_running_model

@@ -27,8 +28,8 @@ def get_model():
     if cached_response:
         return cached_response

-    model, error = get_running_model()
-    if not model:
+    model_name, error = get_running_model()
+    if not model_name:
         response = jsonify({
             'code': 502,
             'msg': 'failed to reach backend',
@@ -36,7 +37,7 @@ def get_model():
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
         response = jsonify({
-            'result': model,
+            'result': opts.manual_model_name if opts.manual_model_name else model_name,
             'timestamp': int(time.time())
         }), 200
     cache.set(cache_key, response, timeout=60)
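An illustrative success body from this endpoint with the override set (the model string and timestamp are made up; the structure follows the jsonify call above):

response_body = {
    'result': 'my-public-model-name',  # manual_model_name if set, else backend name
    'timestamp': 1694000000,
}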
@@ -6,6 +6,8 @@ from threading import Thread

 from flask import Flask, jsonify, render_template, request

+from llm_server.routes.server_error import handle_server_error
+
 try:
     import vllm
 except ModuleNotFoundError as e:
@@ -65,6 +67,7 @@ opts.netdata_root = config['netdata_root']
 opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
 opts.show_backend_info = config['show_backend_info']
 opts.max_new_tokens = config['max_new_tokens']
+opts.manual_model_name = config['manual_model_name']

 opts.verify_ssl = config['verify_ssl']
 if not opts.verify_ssl:
@@ -145,7 +148,7 @@ def home():
         llm_middleware_name=config['llm_middleware_name'],
         analytics_tracking_code=analytics_tracking_code,
         info_html=info_html,
-        current_model=running_model,
+        current_model=opts.manual_model_name if opts.manual_model_name else running_model,
         client_api=stats['endpoints']['blocking'],
         ws_client_api=stats['endpoints']['streaming'],
         estimated_wait=estimated_wait_sec,
@@ -169,8 +172,7 @@ def fallback(first=None, rest=None):

 @app.errorhandler(500)
 def server_error(e):
-    print(e)
-    return {'error': True}, 500
+    return handle_server_error(e)


 if __name__ == "__main__":