From 8593198216960cf1c22b2fca492fab49e3551f9e Mon Sep 17 00:00:00 2001 From: Cyberes Date: Wed, 20 Sep 2023 21:19:26 -0600 Subject: [PATCH] close mysql cursor --- llm_server/database/create.py | 1 + llm_server/database/database.py | 109 ++++++++++++++++----------- llm_server/routes/request_handler.py | 12 +-- 3 files changed, 71 insertions(+), 51 deletions(-) diff --git a/llm_server/database/create.py b/llm_server/database/create.py index dd6d458..359f551 100644 --- a/llm_server/database/create.py +++ b/llm_server/database/create.py @@ -38,4 +38,5 @@ def create_db(): ) ''') conn.commit() + cursor.close() diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 583cf2b..327916b 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -33,13 +33,16 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe timestamp = int(time.time()) conn = db_pool.connection() cursor = conn.cursor() - cursor.execute(""" - INSERT INTO prompts - (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, - (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) - conn.commit() + try: + cursor.execute(""" + INSERT INTO prompts + (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + (ip, token, opts.running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) + conn.commit() + finally: + cursor.close() def is_valid_api_key(api_key): @@ -53,9 +56,10 @@ def is_valid_api_key(api_key): disabled = bool(disabled) if (uses is None or uses < max_uses) and (expire is None or expire > time.time()) and not disabled: return True + conn.commit() return False finally: - conn.commit() + cursor.close() def increment_uses(api_key): @@ -67,70 +71,81 @@ def increment_uses(api_key): if row is not None: cursor.execute("UPDATE token_auth SET uses = COALESCE(uses, 0) + 1 WHERE token = %s", (api_key,)) return True + conn.commit() return False finally: - conn.commit() + cursor.close() def get_number_of_rows(table_name): conn = db_pool.connection() cursor = conn.cursor() - cursor.execute(f'SELECT COUNT(*) FROM {table_name}') - result = cursor.fetchone() - conn.commit() - return result[0] + try: + cursor.execute(f'SELECT COUNT(*) FROM {table_name}') + result = cursor.fetchone() + return result[0] + finally: + cursor.close() def average_column(table_name, column_name): conn = db_pool.connection() cursor = conn.cursor() - cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}") - result = cursor.fetchone() - conn.commit() - return result[0] + try: + cursor.execute(f"SELECT AVG({column_name}) FROM {table_name}") + result = cursor.fetchone() + return result[0] + finally: + cursor.close() def average_column_for_model(table_name, column_name, model_name): conn = db_pool.connection() cursor = conn.cursor() - cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s", (model_name,)) - result = cursor.fetchone() - conn.commit() - return result[0] + try: + cursor.execute(f"SELECT AVG({column_name}) FROM {table_name} WHERE model = %s", (model_name,)) + result = cursor.fetchone() + return result[0] + finally: + cursor.close() def weighted_average_column_for_model(table_name, column_name, model_name, backend_name, backend_url, exclude_zeros: bool = False): conn = db_pool.connection() cursor = conn.cursor() - cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,)) - results = cursor.fetchall() + try: + cursor.execute(f"SELECT {column_name}, id FROM {table_name} WHERE model = %s AND backend_mode = %s AND backend_url = %s ORDER BY id DESC", (model_name, backend_name, backend_url,)) + results = cursor.fetchall() - total_weight = 0 - weighted_sum = 0 - for i, (value, rowid) in enumerate(results): - if value is None or (exclude_zeros and value == 0): - continue - weight = i + 1 - total_weight += weight - weighted_sum += weight * value + total_weight = 0 + weighted_sum = 0 + for i, (value, rowid) in enumerate(results): + if value is None or (exclude_zeros and value == 0): + continue + weight = i + 1 + total_weight += weight + weighted_sum += weight * value - if total_weight > 0: - # Avoid division by zero - calculated_avg = weighted_sum / total_weight - else: - calculated_avg = 0 + if total_weight > 0: + # Avoid division by zero + calculated_avg = weighted_sum / total_weight + else: + calculated_avg = 0 - conn.commit() - return calculated_avg + return calculated_avg + finally: + cursor.close() def sum_column(table_name, column_name): conn = db_pool.connection() cursor = conn.cursor() - cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}") - result = cursor.fetchone() - conn.commit() - return result[0] if result[0] else 0 + try: + cursor.execute(f"SELECT SUM({column_name}) FROM {table_name}") + result = cursor.fetchone() + return result[0] if result[0] else 0 + finally: + cursor.close() def get_distinct_ips_24h(): @@ -138,7 +153,9 @@ def get_distinct_ips_24h(): past_24_hours = int(time.time()) - 24 * 60 * 60 conn = db_pool.connection() cursor = conn.cursor() - cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s", (past_24_hours,)) - result = cursor.fetchone() - conn.commit() - return result[0] if result else 0 + try: + cursor.execute("SELECT COUNT(DISTINCT ip) FROM prompts WHERE timestamp >= %s", (past_24_hours,)) + result = cursor.fetchone() + return result[0] if result else 0 + finally: + cursor.close() diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index b0371e8..783074f 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -44,12 +44,14 @@ class RequestHandler: if self.token: conn = db_pool.connection() cursor = conn.cursor() - cursor.execute("SELECT priority FROM token_auth WHERE token = %s", (self.token,)) - result = cursor.fetchone() + try: + cursor.execute("SELECT priority FROM token_auth WHERE token = %s", (self.token,)) + result = cursor.fetchone() - if result: - return result[0] - conn.commit() + if result: + return result[0] + finally: + cursor.close() return DEFAULT_PRIORITY def get_parameters(self):