local-llm-server/llm_server/cluster/model_choices.py

93 lines
4.4 KiB
Python

import numpy as np
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_model, get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers
# TODO: give this a better name!
def _build_model_entry(model, base_client_api):
    """Build the endpoint/statistics dict advertised for one running model.

    Args:
        model: Model name, as yielded by ``get_running_models()``.
        base_client_api: Host (and optional path prefix) used to build the
            ``https://`` / ``wss://`` client URLs.

    Returns:
        A dict with the client URLs, backend count, queue/processing counts,
        the average generation time and a human-readable ``estimated_wait``
        string. ``context_size`` is only present when at least one backend
        reported a model config.
    """
    backends = get_backends_from_model(model)

    context_sizes = []
    generation_times = []
    for backend_url in backends:
        backend_info = cluster_config.get_backend(backend_url)
        # Backends may not have reported these values; skip the ones that are
        # missing or falsy.
        if backend_info.get('model_config'):
            context_sizes.append(backend_info['model_config']['max_position_embeddings'])
        if backend_info.get('average_generation_elapsed_sec'):
            generation_times.append(backend_info['average_generation_elapsed_sec'])

    active_gen_workers = get_active_gen_workers(model)
    proompters_in_queue = priority_queue.len(model)

    # With no timing data from any backend, fall back to 0 rather than
    # averaging an empty list.
    average_generation_elapsed_sec = np.average(generation_times) if generation_times else 0

    estimated_wait_sec = calculate_wait_time(
        average_generation_elapsed_sec,
        proompters_in_queue,
        opts.concurrent_gens,
        active_gen_workers,
    )
    if proompters_in_queue == 0 and active_gen_workers >= opts.concurrent_gens:
        # There will be a wait if the queue is empty but prompts are
        # processing, but we don't know how long.
        estimated_wait = f"less than {estimated_wait_sec} seconds"
    else:
        estimated_wait = f"{estimated_wait_sec} seconds"

    entry = {
        'client_api': f'https://{base_client_api}/{model}',
        'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
        'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
        'backend_count': len(backends),
        'estimated_wait': estimated_wait,
        'queued': proompters_in_queue,
        'processing': active_gen_workers,
        'avg_generation_time': average_generation_elapsed_sec,
    }
    if context_sizes:
        # Advertise the smallest context size among this model's backends.
        entry['context_size'] = min(context_sizes)
    return entry


def _build_default_backend_entry(base_client_api):
    """Build the endpoint/statistics dict for the default cluster backend.

    Args:
        base_client_api: Host (and optional path prefix) used to build the
            ``https://`` / ``wss://`` client URLs.

    Returns:
        An empty dict when ``get_a_cluster_backend()`` returns nothing,
        otherwise a dict describing the default backend. ``context_size`` is
        ``None`` when the backend has no model config.
    """
    default_backend = get_a_cluster_backend()
    if not default_backend:
        return {}

    backend_info = cluster_config.get_backend(default_backend)

    # Same guard as the per-model path: a backend without a model config
    # should not take the whole listing down.
    if backend_info.get('model_config'):
        context_size = backend_info['model_config']['max_position_embeddings']
    else:
        context_size = None

    average_generation_elapsed_sec = backend_info.get('average_generation_elapsed_sec')
    active_gen_workers = redis.get(f'active_gen_workers:{default_backend}', dtype=int, default=0)
    proompters_in_queue = priority_queue.len(backend_info['model'])

    # The per-model path feeds 0 into the wait-time calculation when no timing
    # data exists; do the same here instead of passing None.
    # NOTE(review): assumes calculate_wait_time() expects a number - confirm.
    estimated_wait_sec = calculate_wait_time(
        average_generation_elapsed_sec or 0,
        proompters_in_queue,
        backend_info['concurrent_gens'],
        active_gen_workers,
    )

    return {
        'client_api': f'https://{base_client_api}',
        'ws_client_api': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
        'openai_client_api': f'https://{base_client_api}/openai' if opts.enable_openi_compatible_backend else 'disabled',
        'estimated_wait': estimated_wait_sec,
        'queued': proompters_in_queue,
        'processing': active_gen_workers,
        'context_size': context_size,
        'hash': backend_info['hash'],
        'model': backend_info['model'],
        # Reported as stored by the backend (may be None when unknown).
        'avg_generation_time': average_generation_elapsed_sec,
        'online': True
    }


def get_model_choices(regen: bool = False):
    """Return the per-model listing and the default-backend summary.

    Args:
        regen: When false (the default), a truthy value stored under the
            ``model_choices`` Redis key is returned as-is. When true, the
            data is always rebuilt.

    Returns:
        On a rebuild, a ``(model_choices, default_backend_dict)`` tuple:
        ``model_choices`` maps each running model name to its entry, ordered
        case-insensitively by name; ``default_backend_dict`` describes the
        default backend, or is empty when there is none. On a cache hit,
        whatever ``redis.getp('model_choices')`` returned (presumably the
        same tuple, since that is what is stored below).

    Side effects:
        A rebuild stores the resulting tuple under the ``model_choices``
        Redis key.
    """
    if not regen:
        cached = redis.getp('model_choices')
        if cached:
            return cached

    base_client_api = redis.get('base_client_api', dtype=str)

    model_choices = {
        model: _build_model_entry(model, base_client_api)
        for model in get_running_models()
    }

    # Python wants to sort lowercase vs. uppercase letters differently, so
    # compare on the upper-cased name.
    model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))

    default_backend_dict = _build_default_backend_entry(base_client_api)

    redis.setp('model_choices', (model_choices, default_backend_dict))
    return model_choices, default_backend_dict