diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py
index 61061bb..cadf86e 100644
--- a/llm_server/cluster/backend.py
+++ b/llm_server/cluster/backend.py
@@ -1,8 +1,10 @@
+from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
 from llm_server.cluster.stores import redis_running_models
 from llm_server.llm.generator import generator
 from llm_server.llm.info import get_info
+from llm_server.routes.helpers.model import estimate_model_size


 def test_backend(backend_url: str, test_prompt: bool = False):
@@ -34,11 +36,19 @@ def get_backends():
         status = b.get('online', False)
         priority = b['priority']
         result[k] = {'status': status, 'priority': priority}
-    online_backends = sorted(
-        ((url, info) for url, info in backends.items() if info['online']),
-        key=lambda kv: -kv[1]['priority'],
-        reverse=True
-    )
+
+    if not opts.prioritize_by_size:
+        online_backends = sorted(
+            ((url, info) for url, info in backends.items() if info['online']),
+            key=lambda kv: -kv[1]['priority'],
+            reverse=True
+        )
+    else:
+        online_backends = sorted(
+            ((url, info) for url, info in backends.items() if info['online']),
+            key=lambda kv: estimate_model_size(kv[1]['model_config']),
+            reverse=True
+        )
     offline_backends = sorted(
         ((url, info) for url, info in backends.items() if not info['online']),
         key=lambda kv: -kv[1]['priority'],
diff --git a/llm_server/config/config.py b/llm_server/config/config.py
index 11092c0..54eb3ec 100644
--- a/llm_server/config/config.py
+++ b/llm_server/config/config.py
@@ -35,7 +35,8 @@ config_default_vars = {
     'show_backends': True,
     'cluster_workers': 30,
     'background_homepage_cacher': True,
-    'openai_moderation_timeout': 5
+    'openai_moderation_timeout': 5,
+    'prioritize_by_size': False
 }

 config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
index 6f9db8d..9a55a70 100644
--- a/llm_server/config/load.py
+++ b/llm_server/config/load.py
@@ -49,6 +49,7 @@ def load_config(config_path):
     opts.background_homepage_cacher = config['background_homepage_cacher']
     opts.openai_moderation_timeout = config['openai_moderation_timeout']
     opts.frontend_api_mode = config['frontend_api_mode']
+    opts.prioritize_by_size = config['prioritize_by_size']

     if opts.openai_expose_our_model and not opts.openai_api_key:
         print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
diff --git a/llm_server/opts.py b/llm_server/opts.py
index 38542a8..5c32f05 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -37,3 +37,4 @@ show_backends = True
 cluster_workers = 30
 background_homepage_cacher = True
 openai_moderation_timeout = 5
+prioritize_by_size = False
\ No newline at end of file
diff --git a/llm_server/routes/helpers/model.py b/llm_server/routes/helpers/model.py
new file mode 100644
index 0000000..ca35867
--- /dev/null
+++ b/llm_server/routes/helpers/model.py
@@ -0,0 +1,13 @@
+def estimate_model_size(config: dict):
+    """
+    Estimate the size of a model from its config. No idea if this is correct,
+    but it allows us to compare models.
+    :param config:
+    :return:
+    """
+    vocab_size = config['vocab_size']
+    hidden_size = config['hidden_size']
+    num_hidden_layers = config['num_hidden_layers']
+    intermediate_size = config['intermediate_size']
+    total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3)))
+    return int(total_params / 1e9)
diff --git a/server.py b/server.py
index 478a028..560f15d 100644
--- a/server.py
+++ b/server.py
@@ -24,6 +24,7 @@ from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.sock import init_socketio

+# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
 # TODO: redis SCAN vs KEYS??
 # TODO: implement blind RRD controlled via header and only used when there is a queue on the primary backend(s)
 # TODO: is frequency penalty the same as ooba repetition penalty???
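For reference, a minimal sanity check of estimate_model_size as added above (not part of the patch). The config values here are only illustrative, approximating a LLaMA-2-13B-style config; the function intentionally returns a rough whole-number billion-parameter figure suitable only for relative ordering of backends.

    # sketch: exercise the new size estimator with an assumed, hand-written model config
    from llm_server.routes.helpers.model import estimate_model_size

    example_config = {
        'vocab_size': 32000,
        'hidden_size': 5120,
        'num_hidden_layers': 40,
        'intermediate_size': 13824,
    }
    print(estimate_model_size(example_config))  # prints 14 -- roughly a 13B-class model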