From 41e622d19c02d42f0c181f2f3a7b81ab068ee496 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Sat, 23 Sep 2023 20:55:49 -0600 Subject: [PATCH] fix two exceptions --- llm_server/database/database.py | 9 +++++++-- llm_server/routes/v1/generate_stream.py | 4 ++++ llm_server/threads.py | 6 ++++-- 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 5f00829..d0f7e56 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -1,5 +1,6 @@ import json import time +import traceback import llm_server from llm_server import opts @@ -114,8 +115,12 @@ def weighted_average_column_for_model(table_name, column_name, model_name, backe conn = db_pool.connection() cursor = conn.cursor() try: - cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,)) - results = cursor.fetchall() + try: + cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,)) + results = cursor.fetchall() + except Exception: + traceback.print_exc() + return -1 total_weight = 0 weighted_sum = 0 diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 43511db..55f3404 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -70,6 +70,10 @@ def stream(ws): if partial_response.endswith(b'\x00'): json_str = partial_response[:-1].decode() # Remove the null character and decode the byte string to a string json_obj = json.loads(json_str) + if not len(json_obj['text'][0].split(input_prompt + generated_text)): + # ???? + continue + new = json_obj['text'][0].split(input_prompt + generated_text)[1] ws.send(json.dumps({ diff --git a/llm_server/threads.py b/llm_server/threads.py index 23da835..0e48383 100644 --- a/llm_server/threads.py +++ b/llm_server/threads.py @@ -47,13 +47,15 @@ class MainBackgroundThread(Thread): # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True) or 0 - redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec) + if average_generation_elapsed_sec > -1: + redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec) # overall = average_column_for_model('prompts', 'generation_time', opts.running_model) # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}') average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', opts.running_model, opts.mode, opts.backend_url, exclude_zeros=True) or 0 - redis.set('average_output_tokens', average_output_tokens) + if average_generation_elapsed_sec > -1: + redis.set('average_output_tokens', average_output_tokens) # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model) # print(f'Weighted: {average_output_tokens}, overall: {overall}')