From fab7b7ccdd9437f0de0c1f438462e184b97ad088 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sat, 23 Sep 2023 21:17:13 -0600
Subject: [PATCH] active gen workers wait

---
 llm_server/routes/v1/generate_stats.py | 7 +++----
 server.py                              | 7 ++++---
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index 17b8bd4..75cd373 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -12,10 +12,9 @@ from llm_server.routes.stats import calculate_avg_gen_time, get_active_gen_worke
 
 
 def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
-    workers_running = gen_time_calc if active_gen_workers > 0 else 0
-    if proompters_in_queue < concurrent_gens:
+    if active_gen_workers < concurrent_gens:
         return 0
-    elif proompters_in_queue >= concurrent_gens:
+    elif active_gen_workers >= concurrent_gens:
         # Calculate how long it will take to complete the currently running gens and the queued requests.
         # If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
         # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
@@ -25,7 +24,7 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
         # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
         proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
             else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
-        return proompters_in_queue_wait_time + workers_running
+        return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
     elif proompters_in_queue == 0 and active_gen_workers == 0:
         # No queue, no workers
         return 0
diff --git a/server.py b/server.py
index d9d01b7..5653be0 100644
--- a/server.py
+++ b/server.py
@@ -2,6 +2,7 @@ import os
 import sys
 from pathlib import Path
 from threading import Thread
+
 import simplejson as json
 from flask import Flask, jsonify, render_template, request
 
@@ -28,7 +29,7 @@ from llm_server.helpers import resolve_path
 from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import cache, redis
 from llm_server.routes.queue import start_workers
-from llm_server.routes.stats import SemaphoreCheckerThread, process_avg_gen_time
+from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.stream import init_socketio
@@ -97,7 +98,6 @@ if config['average_generation_time_mode'] not in ['database', 'minute']:
     sys.exit(1)
 opts.average_generation_time_mode = config['average_generation_time_mode']
 
-
 if opts.mode == 'oobabooga':
     raise NotImplementedError
     # llm_server.llm.tokenizer = OobaboogaBackend()
@@ -142,7 +142,8 @@ def home():
     else:
         running_model = opts.running_model
 
-    if stats['queue']['queued'] == 0 and stats['queue']['processing'] > 0:
+    active_gen_workers = get_active_gen_workers()
+    if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
         # There will be a wait if the queue is empty but prompts are processing, but we don't
         # know how long.
         estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
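
For reference, a minimal standalone sketch of the new wait-time estimate, assuming
round_up_base() rounds a value up to the nearest multiple of `base` (the actual
helper in llm_server may differ); the local names and the worked example are
illustrative only:

    import math


    def round_up_base(value, base):
        # Assumed behaviour: round `value` up to the nearest multiple of `base`.
        return base * math.ceil(value / base)


    def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
        if active_gen_workers < concurrent_gens:
            # A worker slot is free, so a new request should start right away.
            return 0
        # Every worker slot is busy: estimate how many "rounds" of generation the
        # queue needs, then add one generation time for the batch already running.
        queue_wait = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
            else round_up_base((proompters_in_queue / concurrent_gens) * gen_time_calc, base=gen_time_calc)
        # The conditional expression wraps the whole sum, as in the patched return line.
        return queue_wait + gen_time_calc if active_gen_workers > 0 else 0


    # Example: 10 s per generation, 7 queued prompts, 3 concurrent slots, all busy:
    # ceil(7 / 3) * 10 s for the queue plus 10 s for the running batch = 40 s.
    print(calculate_wait_time(10, 7, 3, 3))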