From edf13db324f3aa0c5020d39c98ceb23e6ac09cae Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sun, 17 Sep 2023 18:33:57 -0600
Subject: [PATCH] calculate estimated wait time better

---
 llm_server/helpers.py                       |  5 +++++
 llm_server/routes/openai_request_handler.py |  1 +
 llm_server/routes/v1/generate_stats.py      | 35 ++++++++++++++++++++-------
 3 files changed, 34 insertions(+), 7 deletions(-)

diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index 55df351..bf56fac 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -1,4 +1,5 @@
 import json
+import math
 from collections import OrderedDict
 from pathlib import Path
 
@@ -57,3 +58,7 @@ def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
     response.headers['mimetype'] = 'application/json'
     response.status_code = status
     return response
+
+
+def round_up_base(n, base):
+    return math.ceil(n / base) * base
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index c2d5401..e5530e3 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -48,6 +48,7 @@ class OpenAIRequestHandler(RequestHandler):
         flagged = False
         flagged_categories = []
 
+        # TODO: make this threaded
         for msg in msgs_to_check:
             flagged, categories = check_moderation_endpoint(msg)
             flagged_categories.extend(categories)
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 783997a..3b5ef58 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -3,7 +3,7 @@ from datetime import datetime
 
 from llm_server import opts
 from llm_server.database import get_distinct_ips_24h, sum_column
-from llm_server.helpers import deep_sort
+from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
 from llm_server.routes.cache import redis
@@ -11,6 +11,31 @@ from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
 
 
+def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
+    workers_running = gen_time_calc if active_gen_workers > 0 else 0
+    if proompters_in_queue > 0:
+        # Calculate how long it will take to complete the currently running gens and the queued requests.
+        # If the proompters in the queue fit within the number of concurrent workers, just use the
+        # calculated generation time. Otherwise, divide the queue size by how many requests we can process
+        # concurrently and multiply that by the calculated generation time. Then, round that number up to
+        # the nearest multiple of gen_time_calc (i.e. if gen_time_calc is 8 and the calculated number is
+        # 11.6, we will get 16). Finally, add gen_time_calc to account for the currently running generations.
+        # This assumes that all active workers will finish at the same time, which is unlikely.
+        # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
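+        # Worked example (hypothetical numbers): gen_time_calc = 8, 10 proompters queued, 3 concurrent
+        # gens, workers active: (10 / 3) * 8 = 26.67, which rounds up to 32; adding 8 for the running
+        # batch gives an estimated 40 seconds.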
+        proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
+            else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc) + workers_running
+        return proompters_in_queue_wait_time
+    elif proompters_in_queue == 0 and active_gen_workers == 0:
+        # No queue, no workers
+        return 0
+    else:
+        # No queue, but workers are still generating
+        return gen_time_calc
+
+
 # TODO: have routes/__init__.py point to the latest API version generate_stats()
 
 def generate_stats():
@@ -42,11 +67,8 @@ def generate_stats():
         # the backend knows that. So, let's just stick with the elapsed time.
         gen_time_calc = average_generation_time
 
-        estimated_wait_sec = (
-            (gen_time_calc * proompters_in_queue) / opts.concurrent_gens  # Calculate wait time for items in queue
-        ) + (
-            active_gen_workers * gen_time_calc  # Calculate wait time for in-process items
-        ) if estimated_avg_tps > 0 else 0
+        estimated_wait_sec = calculate_wait_time(gen_time_calc, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
+
     elif opts.average_generation_time_mode == 'minute':
         average_generation_time = calculate_avg_gen_time()
         gen_time_calc = average_generation_time
@@ -65,7 +87,6 @@ def generate_stats():
     else:
         netdata_stats = {}
 
-
     output = {
         'stats': {
             'proompters': {
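
Sanity check of the new estimate (a minimal sketch, assuming only the two
functions added above; the sample numbers are hypothetical):

    from llm_server.helpers import round_up_base
    from llm_server.routes.v1.generate_stats import calculate_wait_time

    print(round_up_base(11.6, 8))            # 16, the example from the comment
    # 8s average gen time, 10 queued, 3 concurrent workers, all busy:
    print(calculate_wait_time(8, 10, 3, 3))  # (10 / 3) * 8 = 26.67 -> 32, + 8 running = 40
    print(calculate_wait_time(8, 2, 3, 3))   # 8: the queue fits within one batch of workers
    print(calculate_wait_time(8, 0, 3, 2))   # 8: no queue, but workers are still generating
    print(calculate_wait_time(8, 0, 3, 0))   # 0: fully idle

Note that when the queue fits within one batch, workers_running is not added;
that branch returns gen_time_calc alone, matching the comment's "just use the
calculated generation time".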