diff --git a/llm_server/database.py b/llm_server/database.py
index 1edde9a..41a5fa1 100644
--- a/llm_server/database.py
+++ b/llm_server/database.py
@@ -45,8 +45,10 @@ def init_db():
     conn.close()


-def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None):
+def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, response_tokens: int = None, is_error: bool = False):
     prompt_tokens = len(tokenizer.encode(prompt))
+
+    # TODO: insert None for response tokens when error
     if not response_tokens:
         response_tokens = len(tokenizer.encode(response))

@@ -54,9 +56,15 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     # usually we want to insert a float.
     if gen_time:
         gen_time = round(gen_time, 3)
+    if is_error:
+        gen_time = None

     if not opts.log_prompts:
-        prompt = response = None
+        prompt = None
+
+    if not opts.log_prompts and not is_error:
+        # TODO: test and verify this works as expected
+        response = None

     timestamp = int(time.time())
     conn = sqlite3.connect(opts.database_path)
diff --git a/llm_server/llm/hf_textgen/__init__.py b/llm_server/llm/hf_textgen/__init__.py
index e69de29..c1e8529 100644
--- a/llm_server/llm/hf_textgen/__init__.py
+++ b/llm_server/llm/hf_textgen/__init__.py
@@ -0,0 +1 @@
+# https://huggingface.github.io/text-generation-inference
diff --git a/llm_server/llm/hf_textgen/generate.py b/llm_server/llm/hf_textgen/generate.py
index ee20c04..9b4c9a9 100644
--- a/llm_server/llm/hf_textgen/generate.py
+++ b/llm_server/llm/hf_textgen/generate.py
@@ -14,7 +14,7 @@ def prepare_json(json_data: dict):
     return {
         'inputs': json_data.get('prompt', ''),
         'parameters': {
-            'max_new_tokens': json_data.get('max_new_tokens'),
+            'max_new_tokens': min(json_data.get('max_new_tokens', opts.context_size), opts.context_size),
             'repetition_penalty': json_data.get('repetition_penalty', None),
             'seed': seed,
             'stop': json_data.get('stopping_strings', []),
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 01d0220..cc02979 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -59,7 +59,6 @@ def generate():
     else:
         event = None
     if not event:
-        log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), 429)
         backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
         response_json_body = {
             'results': [
@@ -68,6 +67,7 @@ def generate():
                 }
             ],
         }
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), 429, is_error=True)
         return jsonify({
             **response_json_body
         }), 200
@@ -88,7 +88,7 @@ def generate():
                 }
             ],
         }
-        log_prompt(client_ip, token, request_json_body['prompt'], '', None, parameters, dict(request.headers), response if response else 0)
+        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), response if response else 0, is_error=True)
         return jsonify({
             'code': 500,
             'error': 'failed to reach backend',
@@ -110,6 +110,7 @@ def generate():
     elif opts.mode == 'hf-textgen':
         backend_response = response_json_body.get('generated_text', '')
         if response_json_body.get('error'):
+            backend_err = True
             error_type = response_json_body.get('error_type')
             error_type_string = 'returned an error' if opts.mode == 'oobabooga' else f'returned {indefinite_article(error_type)} {error_type} error'
             response_json_body = {
@@ -132,7 +133,7 @@ def generate():
     else:
         raise Exception
     redis.incr('proompts')
-    log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'))
+    log_prompt(client_ip, token, request_json_body['prompt'], backend_response if not backend_err else '', elapsed_time if not backend_err else None, parameters, dict(request.headers), response.status_code if response else 0, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
     return jsonify({
         **response_json_body
     }), 200
@@ -149,7 +150,7 @@ def generate():
         }
     else:
         raise Exception
-    log_prompt(client_ip, token, request_json_body['prompt'], '', elapsed_time, parameters, dict(request.headers), response.status_code)
+    log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response.status_code, is_error=True)
    return jsonify({
        'code': 500,
        'error': 'the backend did not return valid JSON',
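
Note (not part of the patch): a minimal standalone sketch of the logging behavior the database.py change is aiming for, reusing names from the hunks above (is_error, opts.log_prompts). The real log_prompt() also tokenizes the text and writes a row to SQLite; this only illustrates which values end up stored.

# Standalone sketch, not project code: what the new is_error flag is intended
# to do to the values stored by log_prompt().
def fields_to_store(prompt, response, gen_time, is_error, log_prompts):
    if gen_time:
        gen_time = round(gen_time, 3)
    if is_error:
        gen_time = None  # an error has no meaningful generation time
    if not log_prompts:
        prompt = None  # never store the user's prompt when prompt logging is off
    if not log_prompts and not is_error:
        response = None  # drop the response too, unless it is an error message
    return prompt, response, gen_time

# With prompt logging disabled, an error response is still kept for debugging:
print(fields_to_store('some prompt', 'Ratelimited: too many requests', None, True, False))
# -> (None, 'Ratelimited: too many requests', None)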