calculate estimated wait time better
parent 7434ae1b5b
commit edf13db324
@@ -1,4 +1,5 @@
 import json
+import math
 from collections import OrderedDict
 from pathlib import Path
 
@@ -57,3 +58,7 @@ def jsonify_pretty(json_dict: dict, status=200, indent=4, sort_keys=True):
     response.headers['mimetype'] = 'application/json'
     response.status_code = status
     return response
+
+
+def round_up_base(n, base):
+    return math.ceil(n / base) * base
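For reference, round_up_base(n, base) rounds n up to the next multiple of base. A quick sanity check (illustrative, not part of the commit):

    round_up_base(11.6, 8)  # -> 16: ceil(11.6 / 8) * 8
    round_up_base(16, 8)    # -> 16: exact multiples are unchanged
    round_up_base(17, 8)    # -> 24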
@@ -48,6 +48,7 @@ class OpenAIRequestHandler(RequestHandler):
 
         flagged = False
         flagged_categories = []
+        # TODO: make this threaded
         for msg in msgs_to_check:
             flagged, categories = check_moderation_endpoint(msg)
             flagged_categories.extend(categories)
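The new TODO flags the sequential moderation loop. One possible shape for a threaded version, sketched with concurrent.futures (hypothetical; assumes check_moderation_endpoint is thread-safe, and accumulates flagged with or rather than keeping only the last message's result):

    from concurrent.futures import ThreadPoolExecutor

    flagged = False
    flagged_categories = []
    with ThreadPoolExecutor(max_workers=8) as pool:
        # map() preserves input order and checks messages concurrently
        for msg_flagged, categories in pool.map(check_moderation_endpoint, msgs_to_check):
            flagged = flagged or msg_flagged
            flagged_categories.extend(categories)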
@@ -3,7 +3,7 @@ from datetime import datetime
 from llm_server import opts
 from llm_server.database import get_distinct_ips_24h, sum_column
-from llm_server.helpers import deep_sort
+from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
 from llm_server.routes.cache import redis
@@ -11,6 +11,27 @@ from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import SemaphoreCheckerThread, calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
 
 
+def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
+    workers_running = gen_time_calc if active_gen_workers > 0 else 0
+    if proompters_in_queue > 0:
+        # Calculate how long it will take to complete the currently running generations and the queued requests.
+        # If the queued requests fit within one concurrent batch (proompters_in_queue / concurrent_gens <= 1),
+        # just use the calculated generation time. Otherwise, multiply the number of batches by the calculated
+        # generation time, then round that up to the nearest multiple of gen_time_calc (i.e. if gen_time_calc is 8
+        # and the calculated number is 11.6, we will get 16). Finally, add gen_time_calc to account for the
+        # currently running generations. This assumes that all active workers will finish at the same time,
+        # which is unlikely. Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
+        proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
+            else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc) + workers_running
+        return proompters_in_queue_wait_time
+    elif proompters_in_queue == 0 and active_gen_workers == 0:
+        # No queue, no workers
+        return 0
+    else:
+        # No queue, but workers are running
+        return gen_time_calc
+
+
 # TODO: have routes/__init__.py point to the latest API version generate_stats()
 
 
 def generate_stats():
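A worked example of the new estimator, with illustrative numbers (gen_time_calc=8, concurrent_gens=3):

    calculate_wait_time(8, 7, 3, 3)  # -> 32: round_up_base((7 / 3) * 8, base=8) = 24, plus 8 for the running workers
    calculate_wait_time(8, 2, 3, 3)  # -> 8: the queue fits within one concurrent batch
    calculate_wait_time(8, 0, 3, 3)  # -> 8: no queue, workers still finishing
    calculate_wait_time(8, 0, 3, 0)  # -> 0: fully idle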
@@ -42,11 +63,8 @@ def generate_stats():
         # the backend knows that. So, let's just stick with the elapsed time.
         gen_time_calc = average_generation_time
 
-        estimated_wait_sec = (
-            (gen_time_calc * proompters_in_queue) / opts.concurrent_gens  # Calculate wait time for items in queue
-        ) + (
-            active_gen_workers * gen_time_calc  # Calculate wait time for in-process items
-        ) if estimated_avg_tps > 0 else 0
+        estimated_wait_sec = calculate_wait_time(gen_time_calc, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
     elif opts.average_generation_time_mode == 'minute':
         average_generation_time = calculate_avg_gen_time()
         gen_time_calc = average_generation_time
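With the same illustrative numbers (gen_time_calc=8, proompters_in_queue=7, concurrent_gens=3, active_gen_workers=3), the removed expression gives (8 * 7) / 3 + 3 * 8 ≈ 42.7 seconds, while calculate_wait_time returns 32: the new estimator charges one gen_time_calc for all running workers instead of one per worker, and rounds queued work up to whole batches.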
@@ -65,7 +83,6 @@ def generate_stats():
     else:
         netdata_stats = {}
 
-
     output = {
         'stats': {
             'proompters': {