import numpy as np

from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_model, get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers


# TODO: give this a better name!
def get_model_choices(regen: bool = False):
    """Build the per-model choice table and the default backend's info.

    Returns a tuple of (model_choices, default_backend_dict). The result is cached
    in Redis under 'model_choices'; pass regen=True to force a rebuild.
    """
    if not regen:
        c = redis.getp('model_choices')
        if c:
            return c

    base_client_api = redis.get('base_client_api', dtype=str)
    running_models = get_running_models()

    model_choices = {}
    for model in running_models:
        b = get_backends_from_model(model)

        # Collect the context size and average generation time reported by each
        # backend serving this model.
        context_size = []
        avg_gen_per_worker = []
        for backend_url in b:
            backend_info = cluster_config.get_backend(backend_url)
            if backend_info.get('model_config'):
                context_size.append(backend_info['model_config']['max_position_embeddings'])
            if backend_info.get('average_generation_elapsed_sec'):
                avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])

        active_gen_workers = get_active_gen_workers(model)
        proompters_in_queue = priority_queue.len(model)

        if len(avg_gen_per_worker):
            average_generation_elapsed_sec = np.average(avg_gen_per_worker)
        else:
            average_generation_elapsed_sec = 0
        estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)

        if proompters_in_queue == 0 and active_gen_workers >= opts.concurrent_gens:
            # The queue is empty but every worker is busy, so a new prompt will still
            # have to wait. We can't know exactly how long, so present the estimate
            # as an upper bound.
            estimated_wait_sec = f"less than {estimated_wait_sec} seconds"
        else:
            estimated_wait_sec = f"{estimated_wait_sec} seconds"

        model_choices[model] = {
            'client_api': f'https://{base_client_api}/{model}',
            'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
            'openai_client_api': f'https://{base_client_api}/openai/{model}' if opts.enable_openi_compatible_backend else 'disabled',
            'backend_count': len(b),
            'estimated_wait': estimated_wait_sec,
            'queued': proompters_in_queue,
            'processing': active_gen_workers,
            'avg_generation_time': average_generation_elapsed_sec,
        }

        if len(context_size):
            model_choices[model]['context_size'] = min(context_size)

    # Sort case-insensitively; Python orders uppercase letters before lowercase ones by default.
    model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))

    default_backend = get_a_cluster_backend()
    default_backend_dict = {}
    if default_backend:
        default_backend_info = cluster_config.get_backend(default_backend)
        default_context_size = default_backend_info['model_config']['max_position_embeddings']
        default_average_generation_elapsed_sec = default_backend_info.get('average_generation_elapsed_sec')
        default_active_gen_workers = redis.get(f'active_gen_workers:{default_backend}', dtype=int, default=0)
        default_proompters_in_queue = priority_queue.len(default_backend_info['model'])
        default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)

        default_backend_dict = {
            'client_api': f'https://{base_client_api}',
            'ws_client_api': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
            'openai_client_api': f'https://{base_client_api}/openai' if opts.enable_openi_compatible_backend else 'disabled',
            'estimated_wait': default_estimated_wait_sec,
            'queued': default_proompters_in_queue,
            'processing': default_active_gen_workers,
            'context_size': default_context_size,
            'hash': default_backend_info['hash'],
            'model': default_backend_info['model'],
            'avg_generation_time': default_average_generation_elapsed_sec,
            'online': True
        }

    redis.setp('model_choices', (model_choices, default_backend_dict))
    return model_choices, default_backend_dict
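

# A minimal, optional usage sketch rather than part of the server's normal flow: it
# assumes Redis and the cluster config have already been initialized by the main server
# process, and simply prints what get_model_choices() returns.
if __name__ == '__main__':
    choices, default_backend_dict = get_model_choices(regen=True)
    for model_name, info in choices.items():
        print(model_name, info['backend_count'], info['estimated_wait'])
    if default_backend_dict:
        print('default:', default_backend_dict['model'], default_backend_dict['estimated_wait'])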