log errors to database

Cyberes 2023-08-29 14:48:33 -06:00
parent b44dfa2471
commit 23f3fcf579
4 changed files with 17 additions and 7 deletions

View File

@@ -45,8 +45,10 @@ def init_db():
conn.close()
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None):
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None, is_error: bool = False):
prompt_tokens = len(tokenizer.encode(prompt))
# TODO: insert None for response tokens when error
if not response_tokens:
response_tokens = len(tokenizer.encode(response))
@@ -54,9 +56,15 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
# usually we want to insert a float.
if gen_time:
gen_time = round(gen_time, 3)
if is_error:
gen_time = None
if not opts.log_prompts:
prompt = response = None
prompt = None
if not opts.log_prompts and not is_error:
# TODO: test and verify this works as expected
response = None
timestamp = int(time.time())
conn = sqlite3.connect(opts.database_path)
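
Pulled together, the two hunks above give log_prompt the following shape. This is only a sketch assembled from the diff; tokenizer, opts, and the INSERT that follows the connection are assumed from the surrounding module and are not shown here.

def log_prompt(ip, token, prompt, response, gen_time, parameters, headers,
               backend_response_code, response_tokens: int = None, is_error: bool = False):
    prompt_tokens = len(tokenizer.encode(prompt))
    # TODO: insert None for response tokens when error
    if not response_tokens:
        response_tokens = len(tokenizer.encode(response))

    # Usually we want to insert a float, but a failed request has no
    # meaningful generation time, so store NULL instead.
    if gen_time:
        gen_time = round(gen_time, 3)
    if is_error:
        gen_time = None

    # Prompts are only kept when prompt logging is enabled; error responses
    # are kept regardless so failures can be investigated later.
    if not opts.log_prompts:
        prompt = None
    if not opts.log_prompts and not is_error:
        response = None

    timestamp = int(time.time())
    conn = sqlite3.connect(opts.database_path)
    # ... INSERT the row and close the connection (unchanged, omitted) ...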

View File

@@ -0,0 +1 @@
# https://huggingface.github.io/text-generation-inference

View File

@@ -14,7 +14,7 @@ def prepare_json(json_data: dict):
return {
'inputs': json_data.get('prompt', ''),
'parameters': {
'max_new_tokens': json_data.get('max_new_tokens'),
'max_new_tokens': min(json_data.get('max_new_tokens', opts.context_size), opts.context_size),
'repetition_penalty': json_data.get('repetition_penalty', None),
'seed': seed,
'stop': json_data.get('stopping_strings', []),
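
The max_new_tokens change clamps the client-supplied value to the configured context size, and falls back to that size when the field is missing. A small standalone sketch of the same expression (clamp_max_new_tokens and the example values are illustrative only, not part of the codebase):

def clamp_max_new_tokens(json_data: dict, context_size: int) -> int:
    # Fall back to the full context size when the client omits the field,
    # and never allow a request to exceed it.
    return min(json_data.get('max_new_tokens', context_size), context_size)

# With context_size = 2048:
#   {'max_new_tokens': 4096} -> 2048
#   {'max_new_tokens': 200}  -> 200
#   {}                       -> 2048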

View File

@@ -59,7 +59,6 @@ def generate():
else:
event = None
if not event:
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
response_json_body = {
'results': [
@@ -68,6 +67,7 @@ def generate():
}
],
}
log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), 429, is_error=True)
return jsonify({
**response_json_body
}), 200
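
The same pattern is applied to the other error paths below: build the SillyTavern-formatted error text first, log it (instead of an empty string) with is_error=True, and only then return. A rough sketch, assuming the results list carries the error text the way the surrounding handler does:

backend_response = format_sillytavern_err('...error message...', 'error')
response_json_body = {
    'results': [
        {'text': backend_response},  # assumed shape of the results entry
    ],
}
# Logging the formatted error (rather than '') and passing is_error=True
# makes log_prompt record the failure with a NULL gen_time.
log_prompt(client_ip, token, request_json_body['prompt'], backend_response,
           None, parameters, dict(request.headers), 429, is_error=True)
return jsonify({**response_json_body}), 200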
@@ -88,7 +88,7 @@ def generate():
}
],
}
log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), response if response else 0)
log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), response if response else 0, is_error=True)
return jsonify({
'code': 500,
'error': 'failed to reach backend',
@@ -110,6 +110,7 @@ def generate():
elif opts.mode == 'hf-textgen':
backend_response = response_json_body.get('generated_text', '')
if response_json_body.get('error'):
backend_err = True
error_type = response_json_body.get('error_type')
error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
response_json_body = {
@@ -132,7 +133,7 @@ def generate():
else:
raise Exception
redis.incr('proompts')
log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'))
log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({
**response_json_body
}), 200
@@ -149,7 +150,7 @@ def generate():
}
else:
raise Exception
log_prompt(client_ip, token, request_json_body['prompt'], '', elapsed_time, parameters, dict(request.headers), response.status_code)
log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response.status_code, is_error=True)
return jsonify({
'code': 500,
'error': 'the backend did not return valid JSON',