From e7b57cad7b6a5c4a93b2066de2058553da349bb6 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Thu, 28 Sep 2023 18:40:24 -0600 Subject: [PATCH 001/163] set up cluster config and basic background workers --- daemon.py | 2 +- llm_server/cluster/__init__.py | 0 llm_server/cluster/datastore.py | 0 llm_server/cluster/funcs/__init__.py | 0 llm_server/cluster/funcs/backend.py | 26 +++++++++++ llm_server/cluster/redis_config_cache.py | 42 +++++++++++++++++ llm_server/cluster/worker.py | 25 +++++++++++ llm_server/config/load.py | 23 ++++++---- llm_server/config/redis_config.py | 3 ++ .../{routes/cache.py => custom_redis.py} | 45 ++++++++++++++++--- llm_server/database/database.py | 2 +- llm_server/helpers.py | 2 +- llm_server/llm/__init__.py | 2 +- llm_server/llm/info.py | 11 ++--- llm_server/llm/llm_backend.py | 2 +- llm_server/llm/oobabooga/ooba_backend.py | 2 +- llm_server/llm/openai/transform.py | 2 +- llm_server/llm/vllm/generate.py | 2 +- llm_server/opts.py | 1 + llm_server/pre_fork.py | 2 +- llm_server/routes/helpers/client.py | 3 +- llm_server/routes/openai/chat_completions.py | 2 +- llm_server/routes/openai/completions.py | 5 +-- llm_server/routes/openai/info.py | 2 +- llm_server/routes/openai/models.py | 2 +- llm_server/routes/openai/simulated.py | 2 +- llm_server/routes/queue.py | 2 +- llm_server/routes/request_handler.py | 2 +- llm_server/routes/stats.py | 2 +- llm_server/routes/v1/generate_stats.py | 2 +- llm_server/routes/v1/generate_stream.py | 1 - llm_server/routes/v1/info.py | 2 +- llm_server/routes/v1/proxy.py | 4 +- llm_server/workers/blocking.py | 2 +- llm_server/workers/main.py | 3 +- llm_server/workers/printer.py | 2 +- llm_server/workers/recent.py | 2 +- requirements.txt | 1 - server.py | 9 ++-- test-cluster.py | 29 ++++++++++++ 40 files changed, 219 insertions(+), 54 deletions(-) create mode 100644 llm_server/cluster/__init__.py create mode 100644 llm_server/cluster/datastore.py create mode 100644 llm_server/cluster/funcs/__init__.py create mode 100644 llm_server/cluster/funcs/backend.py create mode 100644 llm_server/cluster/redis_config_cache.py create mode 100644 llm_server/cluster/worker.py create mode 100644 llm_server/config/redis_config.py rename llm_server/{routes/cache.py => custom_redis.py} (79%) create mode 100644 test-cluster.py diff --git a/daemon.py b/daemon.py index 20ec300..93e8d34 100644 --- a/daemon.py +++ b/daemon.py @@ -1,6 +1,6 @@ import time -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis try: import gevent.monkey diff --git a/llm_server/cluster/__init__.py b/llm_server/cluster/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_server/cluster/datastore.py b/llm_server/cluster/datastore.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_server/cluster/funcs/__init__.py b/llm_server/cluster/funcs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_server/cluster/funcs/backend.py b/llm_server/cluster/funcs/backend.py new file mode 100644 index 0000000..5b7b535 --- /dev/null +++ b/llm_server/cluster/funcs/backend.py @@ -0,0 +1,26 @@ +from llm_server.cluster.redis_config_cache import RedisClusterStore +from llm_server.llm.info import get_running_model + + +def test_backend(backend_url: str): + running_model, err = get_running_model(backend_url) + if not running_model: + return False + return True + + +def get_best_backends(): + cluster_config = RedisClusterStore('cluster_config') + backends = cluster_config.all() + result = {} + for k, v in backends.items(): + b = 
cluster_config.get_backend(k) + status = b['online'] + priority = b['priority'] + result[k] = {'status': status, 'priority': priority} + online_backends = sorted( + ((url, info) for url, info in backends.items() if info['online']), + key=lambda kv: kv[1]['priority'], + reverse=True + ) + return [url for url, info in online_backends] diff --git a/llm_server/cluster/redis_config_cache.py b/llm_server/cluster/redis_config_cache.py new file mode 100644 index 0000000..00a6a02 --- /dev/null +++ b/llm_server/cluster/redis_config_cache.py @@ -0,0 +1,42 @@ +import pickle + +from llm_server.custom_redis import RedisCustom + + +class RedisClusterStore: + def __init__(self, name: str, **kwargs): + self.name = name + self.config_redis = RedisCustom(name, **kwargs) + + def clear(self): + self.config_redis.flush() + + def load(self, config: dict): + for k, v in config.items(): + self.set_backend(k, v) + + def set_backend(self, name: str, values: dict): + self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()}) + self.set_backend_value(name, 'online', False) + + def set_backend_value(self, key: str, name: str, value): + self.config_redis.hset(key, name, pickle.dumps(value)) + + def get_backend(self, name: str): + r = self.config_redis.hgetall(name) + output = {} + for k, v in r.items(): + output[k.decode('utf8')] = pickle.loads(v) + return output + + def all(self): + keys = self.config_redis.keys('*') + if keys: + result = {} + for key in keys: + if key != f'{self.name}:____': + v = self.get_backend(key) + result[key] = v + return result + else: + return {} diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py new file mode 100644 index 0000000..4aaaf6a --- /dev/null +++ b/llm_server/cluster/worker.py @@ -0,0 +1,25 @@ +import time +from threading import Thread + +from llm_server.cluster.funcs.backend import test_backend +from llm_server.cluster.redis_config_cache import RedisClusterStore + +cluster_config = RedisClusterStore('cluster_config') + + +def cluster_worker(): + while True: + threads = [] + for n, v in cluster_config.all().items(): + thread = Thread(target=check_backend, args=(n, v)) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + time.sleep(10) + + +def check_backend(n, v): + # Check if backends are online + online = test_backend(v['backend_url']) + cluster_config.set_backend_value(n, 'online', online) diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 64469b2..82afe81 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -5,22 +5,17 @@ import openai from llm_server import opts from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars +from llm_server.custom_redis import redis from llm_server.database.conn import database from llm_server.database.database import get_number_of_rows -from llm_server.helpers import resolve_path -from llm_server.routes.cache import redis -def load_config(config_path, script_path): +def load_config(config_path): config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars) success, config, msg = config_loader.load_config() if not success: return success, config, msg - # Resolve relative directory to the directory of the script - if config['database_path'].startswith('./'): - config['database_path'] = resolve_path(script_path, config['database_path'].strip('./')) - if config['mode'] not in ['oobabooga', 'vllm']: print('Unknown mode:', config['mode']) sys.exit(1) @@ -34,7 +29,7 @@ def 
load_config(config_path, script_path): opts.context_size = config['token_limit'] opts.show_num_prompts = config['show_num_prompts'] opts.show_uptime = config['show_uptime'] - opts.backend_url = config['backend_url'].strip('/') + opts.cluster = config['cluster'] opts.show_total_output_tokens = config['show_total_output_tokens'] opts.netdata_root = config['netdata_root'] opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip'] @@ -81,3 +76,15 @@ def load_config(config_path, script_path): redis.set('backend_mode', opts.mode) return success, config, msg + + +def parse_backends(config): + if not config.get('cluster'): + return False + cluster = config.get('cluster') + config = {} + for item in cluster: + backend_url = item['backend_url'].strip('/') + item['backend_url'] = backend_url + config[backend_url] = item + return config diff --git a/llm_server/config/redis_config.py b/llm_server/config/redis_config.py new file mode 100644 index 0000000..06ab1d3 --- /dev/null +++ b/llm_server/config/redis_config.py @@ -0,0 +1,3 @@ +from llm_server.custom_redis import RedisCustom + +redis_config = RedisCustom('redis_config') diff --git a/llm_server/routes/cache.py b/llm_server/custom_redis.py similarity index 79% rename from llm_server/routes/cache.py rename to llm_server/custom_redis.py index d7046db..1b0b5f3 100644 --- a/llm_server/routes/cache.py +++ b/llm_server/custom_redis.py @@ -1,19 +1,20 @@ +import pickle import sys import traceback -from typing import Callable, List, Mapping, Union +from typing import Callable, List, Mapping, Union, Optional import redis as redis_pkg import simplejson as json from flask_caching import Cache from redis import Redis -from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT +from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) ONE_MONTH_SECONDS = 2678000 -class RedisWrapper: +class RedisCustom: """ A wrapper class to set prefixes to keys. 
""" @@ -40,7 +41,6 @@ class RedisWrapper: :param dtype: convert to this type :return: """ - d = self.redis.get(self._key(key)) if dtype and d: try: @@ -129,9 +129,35 @@ class RedisWrapper: ): return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt) + def hset( + self, + name: str, + key: Optional = None, + value=None, + mapping: Optional[dict] = None, + items: Optional[list] = None, + ): + return self.redis.hset(self._key(name), key, value, mapping, items) + def hkeys(self, name: str): return self.redis.hkeys(self._key(name)) + def hmget(self, name: str, keys: List, *args: List): + return self.redis.hmget(self._key(name), keys, *args) + + def hgetall(self, name: str): + return self.redis.hgetall(self._key(name)) + + def keys(self, pattern: PatternT = "*", **kwargs): + raw_keys = self.redis.keys(self._key(pattern), **kwargs) + keys = [] + for key in raw_keys: + p = key.decode('utf-8').split(':') + if len(p) > 2: + del p[0] + keys.append(':'.join(p)) + return keys + def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): return self.set(key, json.dumps(dict_value), ex=ex) @@ -142,6 +168,15 @@ class RedisWrapper: else: return json.loads(r.decode("utf-8")) + def setp(self, name, value): + self.redis.set(name, pickle.dumps(value)) + + def getp(self, name: str): + r = self.redis.get(name) + if r: + return pickle.load(r) + return r + def flush(self): flushed = [] for key in self.redis.scan_iter(f'{self.prefix}:*'): @@ -150,4 +185,4 @@ class RedisWrapper: return flushed -redis = RedisWrapper('local_llm') +redis = RedisCustom('local_llm') diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 9bfe578..3779c83 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -6,7 +6,7 @@ import llm_server from llm_server import opts from llm_server.database.conn import database from llm_server.llm.vllm import tokenize -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False): diff --git a/llm_server/helpers.py b/llm_server/helpers.py index 44b436b..d6eb7d9 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -8,7 +8,7 @@ import simplejson as json from flask import make_response from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def resolve_path(*p: str): diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index 742b1a5..a08b25e 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -1,5 +1,5 @@ from llm_server.llm import oobabooga, vllm -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def get_token_count(prompt: str): diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index 5a529ba..bedf3eb 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -3,20 +3,21 @@ import requests from llm_server import opts -def get_running_model(): - # TODO: cache the results for 1 min so we don't have to keep calling the backend - # TODO: only use one try/catch +def get_running_model(backend_url: str): + # TODO: remove this once we go to Redis + if not backend_url: + backend_url = opts.backend_url if opts.mode == 'oobabooga': try: - backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) 
+ backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['result'], None except Exception as e: return False, e elif opts.mode == 'vllm': try: - backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) + backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['model'], None except Exception as e: diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index 1c11c17..153f66d 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -4,7 +4,7 @@ import flask from llm_server import opts from llm_server.llm import get_token_count -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis class LLMBackend: diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py index 4336756..78f2190 100644 --- a/llm_server/llm/oobabooga/ooba_backend.py +++ b/llm_server/llm/oobabooga/ooba_backend.py @@ -3,7 +3,7 @@ from flask import jsonify from ..llm_backend import LLMBackend from ...database.database import log_prompt from ...helpers import safe_list_get -from ...routes.cache import redis +from llm_server.custom_redis import redis from ...routes.helpers.client import format_sillytavern_err from ...routes.helpers.http import validate_json diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index d5b64e3..8f1898e 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -12,7 +12,7 @@ from flask import jsonify, make_response import llm_server from llm_server import opts from llm_server.llm import get_token_count -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line. ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line. 
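For reference, a minimal usage sketch of the RedisCustom wrapper renamed above in llm_server/custom_redis.py. It assumes only what the module itself shows (an instance built from a prefix string, as in the module-level RedisCustom('local_llm'), talking to Redis on localhost:6379); the prefix name below is an arbitrary example.

    from llm_server.custom_redis import RedisCustom

    # Each instance namespaces its keys, so wrappers with different prefixes
    # can share one Redis database without colliding.
    stats = RedisCustom('example_prefix')
    stats.set('backend_mode', 'vllm')            # stored under "example_prefix:backend_mode"
    mode = stats.get('backend_mode', dtype=str)  # dtype converts the raw bytes back to str
    stats.flush()                                # deletes only the "example_prefix:*" keys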
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 1549f2e..308b1de 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -9,7 +9,7 @@ import requests import llm_server from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis # TODO: make the VLMM backend return TPS and time elapsed diff --git a/llm_server/opts.py b/llm_server/opts.py index de23c7a..5eec1fa 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -37,3 +37,4 @@ openai_moderation_workers = 10 openai_org_name = 'OpenAI' openai_silent_trim = False openai_moderation_enabled = True +cluster = {} diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py index f3ea0f4..21da08e 100644 --- a/llm_server/pre_fork.py +++ b/llm_server/pre_fork.py @@ -2,7 +2,7 @@ import sys from redis import Redis -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis from llm_server.routes.v1.generate_stats import generate_stats diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py index 48e721e..d97e9c5 100644 --- a/llm_server/routes/helpers/client.py +++ b/llm_server/routes/helpers/client.py @@ -1,5 +1,4 @@ -from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def format_sillytavern_err(msg: str, level: str = 'info'): diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index cc27dce..a289c78 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -6,7 +6,7 @@ import traceback from flask import Response, jsonify, request from . import openai_bp -from ..cache import redis +from llm_server.custom_redis import redis from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler from ... import opts diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 503f628..05ac7a5 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -4,13 +4,12 @@ import traceback from flask import jsonify, make_response, request from . import openai_bp -from ..cache import redis -from ..helpers.client import format_sillytavern_err +from llm_server.custom_redis import redis from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler from ... import opts from ...llm import get_token_count -from ...llm.openai.transform import build_openai_response, generate_oai_string +from ...llm.openai.transform import generate_oai_string # TODO: add rate-limit headers? diff --git a/llm_server/routes/openai/info.py b/llm_server/routes/openai/info.py index 54959ae..4fc578a 100644 --- a/llm_server/routes/openai/info.py +++ b/llm_server/routes/openai/info.py @@ -1,7 +1,7 @@ from flask import Response from . import openai_bp -from ..cache import flask_cache +from llm_server.custom_redis import flask_cache from ... import opts diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py index 47223e7..4f732e6 100644 --- a/llm_server/routes/openai/models.py +++ b/llm_server/routes/openai/models.py @@ -4,7 +4,7 @@ import requests from flask import jsonify from . import openai_bp -from ..cache import ONE_MONTH_SECONDS, flask_cache, redis +from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis from ..stats import server_start_time from ... 
import opts from ...helpers import jsonify_pretty diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py index f626490..301e8de 100644 --- a/llm_server/routes/openai/simulated.py +++ b/llm_server/routes/openai/simulated.py @@ -1,7 +1,7 @@ from flask import jsonify from . import openai_bp -from ..cache import ONE_MONTH_SECONDS, flask_cache +from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache from ...llm.openai.transform import generate_oai_string from ..stats import server_start_time diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 84cc614..8d85319 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -6,7 +6,7 @@ from uuid import uuid4 from redis import Redis from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def increment_ip_count(client_ip: str, redis_key): diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index 4b1f640..bb64859 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -11,7 +11,7 @@ from llm_server.helpers import auto_set_base_client_api from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.auth import parse_token -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.queue import priority_queue diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index a6e9e17..a0846d9 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -1,6 +1,6 @@ from datetime import datetime -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis # proompters_5_min = 0 # concurrent_semaphore = Semaphore(concurrent_gens) diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index e144099..b2dd527 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -6,7 +6,7 @@ from llm_server.database.database import get_distinct_ips_24h, sum_column from llm_server.helpers import deep_sort, round_up_base from llm_server.llm.info import get_running_model from llm_server.netdata import get_power_states -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis from llm_server.routes.queue import priority_queue from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 0fc8f40..45fbf12 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -6,7 +6,6 @@ from typing import Union from flask import request -from ..cache import redis from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler from ..queue import decr_active_workers, decrement_ip_count, priority_queue diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py index 2091118..7cdbf0f 100644 --- a/llm_server/routes/v1/info.py +++ b/llm_server/routes/v1/info.py @@ -4,7 +4,7 @@ from flask import jsonify, request from . import bp from ..auth import requires_auth -from ..cache import flask_cache +from llm_server.custom_redis import flask_cache from ... 
import opts from ...llm.info import get_running_model diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py index 4349526..5ffd194 100644 --- a/llm_server/routes/v1/proxy.py +++ b/llm_server/routes/v1/proxy.py @@ -1,8 +1,6 @@ -from flask import jsonify - from . import bp from .generate_stats import generate_stats -from ..cache import flask_cache +from llm_server.custom_redis import flask_cache from ...helpers import jsonify_pretty diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py index 27b0815..dcf0047 100644 --- a/llm_server/workers/blocking.py +++ b/llm_server/workers/blocking.py @@ -3,7 +3,7 @@ import time from llm_server import opts from llm_server.llm.generator import generator -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py index 747f699..f592c5e 100644 --- a/llm_server/workers/main.py +++ b/llm_server/workers/main.py @@ -1,10 +1,9 @@ import time -from threading import Thread from llm_server import opts from llm_server.database.database import weighted_average_column_for_model from llm_server.llm.info import get_running_model -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def main_background_thread(): diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py index cb0f032..6a33835 100644 --- a/llm_server/workers/printer.py +++ b/llm_server/workers/printer.py @@ -1,7 +1,7 @@ import logging import time -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis from llm_server.routes.queue import priority_queue logger = logging.getLogger('console_printer') diff --git a/llm_server/workers/recent.py b/llm_server/workers/recent.py index d378a87..c6158d6 100644 --- a/llm_server/workers/recent.py +++ b/llm_server/workers/recent.py @@ -1,6 +1,6 @@ import time -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def recent_prompters_thread(): diff --git a/requirements.txt b/requirements.txt index 9b0c8eb..7b49eed 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,3 @@ websockets~=11.0.3 basicauth~=1.0.0 openai~=0.28.0 urllib3~=2.0.4 -celery[redis] diff --git a/server.py b/server.py index 06482d4..3c334bc 100644 --- a/server.py +++ b/server.py @@ -8,7 +8,7 @@ except ImportError: pass from llm_server.pre_fork import server_startup -from llm_server.config.load import load_config +from llm_server.config.load import load_config, parse_backends import os import sys from pathlib import Path @@ -36,6 +36,7 @@ from llm_server.stream import init_socketio # TODO: add a way to cancel VLLM gens. Maybe use websockets? # TODO: use coloredlogs # TODO: need to update opts. 
for workers +# TODO: add a healthcheck to VLLM # Lower priority # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens @@ -63,7 +64,7 @@ import config from llm_server import opts from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info -from llm_server.routes.cache import RedisWrapper, flask_cache +from llm_server.custom_redis import RedisCustom, flask_cache from llm_server.llm import redis from llm_server.routes.stats import get_active_gen_workers from llm_server.routes.v1.generate_stats import generate_stats @@ -89,9 +90,11 @@ if not success: database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) create_db() -llm_server.llm.redis = RedisWrapper('local_llm') +llm_server.llm.redis = RedisCustom('local_llm') create_db() +x = parse_backends(config) +print(x) # print(app.url_map) diff --git a/test-cluster.py b/test-cluster.py new file mode 100644 index 0000000..531892b --- /dev/null +++ b/test-cluster.py @@ -0,0 +1,29 @@ +try: + import gevent.monkey + + gevent.monkey.patch_all() +except ImportError: + pass + +import time +from threading import Thread + +from llm_server.cluster.funcs.backend import get_best_backends +from llm_server.cluster.redis_config_cache import RedisClusterStore +from llm_server.cluster.worker import cluster_worker +from llm_server.config.load import parse_backends, load_config + +success, config, msg = load_config('./config/config.yml').resolve().absolute() + +cluster_config = RedisClusterStore('cluster_config') +cluster_config.clear() +cluster_config.load(parse_backends(config)) + +t = Thread(target=cluster_worker) +t.daemon = True +t.start() + +while True: + x = get_best_backends() + print(x) + time.sleep(3) -- 2.34.1 From 624ca74ce5f31bc81292c6507eb7830081d3e625 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Fri, 29 Sep 2023 00:09:44 -0600 Subject: [PATCH 002/163] mvp --- daemon.py | 18 +--- llm_server/cluster/backend.py | 71 ++++++++++++ llm_server/cluster/cluster_config.py | 3 + llm_server/cluster/datastore.py | 0 llm_server/cluster/funcs/__init__.py | 0 llm_server/cluster/funcs/backend.py | 26 ----- llm_server/cluster/redis_config_cache.py | 12 ++- llm_server/cluster/redis_cycle.py | 21 ++++ llm_server/cluster/stores.py | 3 + llm_server/cluster/worker.py | 20 ++-- llm_server/config/config.py | 3 +- llm_server/config/load.py | 1 + llm_server/custom_redis.py | 34 ++++-- llm_server/database/database.py | 91 +++++++++------- llm_server/helpers.py | 2 +- llm_server/integer.py | 12 --- llm_server/llm/__init__.py | 2 +- llm_server/llm/generator.py | 4 +- llm_server/llm/info.py | 10 +- llm_server/llm/llm_backend.py | 2 +- llm_server/llm/openai/transform.py | 4 +- llm_server/llm/vllm/generate.py | 12 +-- llm_server/llm/vllm/vllm_backend.py | 12 +-- llm_server/opts.py | 2 +- llm_server/pre_fork.py | 2 +- llm_server/routes/helpers/client.py | 8 +- llm_server/routes/ooba_request_handler.py | 4 +- llm_server/routes/openai/chat_completions.py | 10 +- llm_server/routes/openai/completions.py | 4 +- llm_server/routes/openai/models.py | 3 +- llm_server/routes/queue.py | 2 +- llm_server/routes/request_handler.py | 22 ++-- llm_server/routes/stats.py | 23 ---- llm_server/routes/v1/generate_stats.py | 101 +++++++++++------- llm_server/routes/v1/generate_stream.py | 25 ++--- llm_server/routes/v1/info.py | 40 +++---- llm_server/{stream.py => sock.py} | 0 llm_server/workers/app.py | 35 ------ .../workers/{blocking.py 
=> inferencer.py} | 8 +- llm_server/workers/main.py | 55 ---------- llm_server/workers/mainer.py | 56 ++++++++++ llm_server/workers/{recent.py => recenter.py} | 0 llm_server/workers/threader.py | 50 +++++++++ llm_server/workers/threads.py | 9 -- gunicorn.py => other/gunicorn.py | 5 + server.py | 49 +++++---- test-cluster.py | 20 +++- 47 files changed, 506 insertions(+), 390 deletions(-) create mode 100644 llm_server/cluster/backend.py create mode 100644 llm_server/cluster/cluster_config.py delete mode 100644 llm_server/cluster/datastore.py delete mode 100644 llm_server/cluster/funcs/__init__.py delete mode 100644 llm_server/cluster/funcs/backend.py create mode 100644 llm_server/cluster/redis_cycle.py create mode 100644 llm_server/cluster/stores.py delete mode 100644 llm_server/integer.py rename llm_server/{stream.py => sock.py} (100%) delete mode 100644 llm_server/workers/app.py rename llm_server/workers/{blocking.py => inferencer.py} (88%) delete mode 100644 llm_server/workers/main.py create mode 100644 llm_server/workers/mainer.py rename llm_server/workers/{recent.py => recenter.py} (100%) create mode 100644 llm_server/workers/threader.py delete mode 100644 llm_server/workers/threads.py rename gunicorn.py => other/gunicorn.py (60%) diff --git a/daemon.py b/daemon.py index 93e8d34..82635f0 100644 --- a/daemon.py +++ b/daemon.py @@ -1,22 +1,12 @@ -import time - -from llm_server.custom_redis import redis - -try: - import gevent.monkey - - gevent.monkey.patch_all() -except ImportError: - pass - import os import sys +import time from pathlib import Path from llm_server.config.load import load_config +from llm_server.custom_redis import redis from llm_server.database.create import create_db - -from llm_server.workers.app import start_background +from llm_server.workers.threader import start_background script_path = os.path.dirname(os.path.realpath(__file__)) config_path_environ = os.getenv("CONFIG_PATH") @@ -29,7 +19,7 @@ if __name__ == "__main__": flushed_keys = redis.flush() print('Flushed', len(flushed_keys), 'keys from Redis.') - success, config, msg = load_config(config_path, script_path) + success, config, msg = load_config(config_path) if not success: print('Failed to load config:', msg) sys.exit(1) diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py new file mode 100644 index 0000000..7b28e86 --- /dev/null +++ b/llm_server/cluster/backend.py @@ -0,0 +1,71 @@ +from llm_server.cluster.redis_config_cache import RedisClusterStore +from llm_server.cluster.redis_cycle import redis_cycle +from llm_server.cluster.stores import redis_running_models +from llm_server.llm.info import get_running_model + + +def test_backend(backend_url: str, mode: str): + running_model, err = get_running_model(backend_url, mode) + if not running_model: + return False + return True + + +def get_backends(): + cluster_config = RedisClusterStore('cluster_config') + backends = cluster_config.all() + result = {} + for k, v in backends.items(): + b = cluster_config.get_backend(k) + status = b['online'] + priority = b['priority'] + result[k] = {'status': status, 'priority': priority} + online_backends = sorted( + ((url, info) for url, info in backends.items() if info['online']), + key=lambda kv: -kv[1]['priority'], + reverse=True + ) + offline_backends = sorted( + ((url, info) for url, info in backends.items() if not info['online']), + key=lambda kv: -kv[1]['priority'], + reverse=True + ) + return [url for url, info in online_backends], [url for url, info in offline_backends] + + +def 
get_a_cluster_backend(): + """ + Get a backend from Redis. If there are no online backends, return None. + """ + online, offline = get_backends() + cycled = redis_cycle('backend_cycler') + c = cycled.copy() + for i in range(len(cycled)): + if cycled[i] in offline: + del c[c.index(cycled[i])] + if len(c): + return c[0] + else: + return None + + +def get_backends_from_model(model_name: str): + cluster_config = RedisClusterStore('cluster_config') + a = cluster_config.all() + matches = [] + for k, v in a.items(): + if v['online'] and v['running_model'] == model_name: + matches.append(k) + return matches + + +def purge_backend_from_running_models(backend_url: str): + keys = redis_running_models.keys() + pipeline = redis_running_models.pipeline() + for model in keys: + pipeline.srem(model, backend_url) + pipeline.execute() + + +def is_valid_model(model_name: str): + return redis_running_models.exists(model_name) diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py new file mode 100644 index 0000000..14a6cb0 --- /dev/null +++ b/llm_server/cluster/cluster_config.py @@ -0,0 +1,3 @@ +from llm_server.cluster.redis_config_cache import RedisClusterStore + +cluster_config = RedisClusterStore('cluster_config') diff --git a/llm_server/cluster/datastore.py b/llm_server/cluster/datastore.py deleted file mode 100644 index e69de29..0000000 diff --git a/llm_server/cluster/funcs/__init__.py b/llm_server/cluster/funcs/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/llm_server/cluster/funcs/backend.py b/llm_server/cluster/funcs/backend.py deleted file mode 100644 index 5b7b535..0000000 --- a/llm_server/cluster/funcs/backend.py +++ /dev/null @@ -1,26 +0,0 @@ -from llm_server.cluster.redis_config_cache import RedisClusterStore -from llm_server.llm.info import get_running_model - - -def test_backend(backend_url: str): - running_model, err = get_running_model(backend_url) - if not running_model: - return False - return True - - -def get_best_backends(): - cluster_config = RedisClusterStore('cluster_config') - backends = cluster_config.all() - result = {} - for k, v in backends.items(): - b = cluster_config.get_backend(k) - status = b['online'] - priority = b['priority'] - result[k] = {'status': status, 'priority': priority} - online_backends = sorted( - ((url, info) for url, info in backends.items() if info['online']), - key=lambda kv: kv[1]['priority'], - reverse=True - ) - return [url for url, info in online_backends] diff --git a/llm_server/cluster/redis_config_cache.py b/llm_server/cluster/redis_config_cache.py index 00a6a02..ebb6099 100644 --- a/llm_server/cluster/redis_config_cache.py +++ b/llm_server/cluster/redis_config_cache.py @@ -1,3 +1,4 @@ +import hashlib import pickle from llm_server.custom_redis import RedisCustom @@ -13,14 +14,17 @@ class RedisClusterStore: def load(self, config: dict): for k, v in config.items(): - self.set_backend(k, v) + self.add_backend(k, v) - def set_backend(self, name: str, values: dict): + def add_backend(self, name: str, values: dict): self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()}) self.set_backend_value(name, 'online', False) + h = hashlib.sha256(name.encode('utf-8')).hexdigest() + self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}') - def set_backend_value(self, key: str, name: str, value): - self.config_redis.hset(key, name, pickle.dumps(value)) + def set_backend_value(self, backend: str, key: str, value): + # By storing the value as a pickle we don't have to cast 
anything when getting the value from Redis. + self.config_redis.hset(backend, key, pickle.dumps(value)) def get_backend(self, name: str): r = self.config_redis.hgetall(name) diff --git a/llm_server/cluster/redis_cycle.py b/llm_server/cluster/redis_cycle.py new file mode 100644 index 0000000..87893ba --- /dev/null +++ b/llm_server/cluster/redis_cycle.py @@ -0,0 +1,21 @@ +import redis + +r = redis.Redis(host='localhost', port=6379, db=9) + + +def redis_cycle(list_name): + while True: + pipe = r.pipeline() + pipe.lpop(list_name) + popped_element = pipe.execute()[0] + if popped_element is None: + return None + r.rpush(list_name, popped_element) + new_list = r.lrange(list_name, 0, -1) + return [x.decode('utf-8') for x in new_list] + + +def load_backend_cycle(list_name: str, elements: list): + r.delete(list_name) + for element in elements: + r.rpush(list_name, element) diff --git a/llm_server/cluster/stores.py b/llm_server/cluster/stores.py new file mode 100644 index 0000000..c0cbdcc --- /dev/null +++ b/llm_server/cluster/stores.py @@ -0,0 +1,3 @@ +from llm_server.custom_redis import RedisCustom + +redis_running_models = RedisCustom('running_models') diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py index 4aaaf6a..bee280a 100644 --- a/llm_server/cluster/worker.py +++ b/llm_server/cluster/worker.py @@ -1,10 +1,10 @@ -import time +from datetime import datetime from threading import Thread -from llm_server.cluster.funcs.backend import test_backend -from llm_server.cluster.redis_config_cache import RedisClusterStore - -cluster_config = RedisClusterStore('cluster_config') +from llm_server.cluster.backend import purge_backend_from_running_models, test_backend +from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.stores import redis_running_models +from llm_server.llm.info import get_running_model def cluster_worker(): @@ -16,10 +16,16 @@ def cluster_worker(): threads.append(thread) for thread in threads: thread.join() - time.sleep(10) def check_backend(n, v): # Check if backends are online - online = test_backend(v['backend_url']) + # TODO: also have test_backend() get the uptime + online = test_backend(v['backend_url'], v['mode']) + if online: + running_model, err = get_running_model(v['backend_url'], v['mode']) + if not err: + cluster_config.set_backend_value(n, 'running_model', running_model) + purge_backend_from_running_models(n) + redis_running_models.sadd(running_model, n) cluster_config.set_backend_value(n, 'online', online) diff --git a/llm_server/config/config.py b/llm_server/config/config.py index 59568d7..b98ea49 100644 --- a/llm_server/config/config.py +++ b/llm_server/config/config.py @@ -32,7 +32,8 @@ config_default_vars = { 'openai_org_name': 'OpenAI', 'openai_silent_trim': False, 'openai_moderation_enabled': True, - 'netdata_root': None + 'netdata_root': None, + 'show_backends': True, } config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 82afe81..09fb127 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -52,6 +52,7 @@ def load_config(config_path): opts.openai_org_name = config['openai_org_name'] opts.openai_silent_trim = config['openai_silent_trim'] opts.openai_moderation_enabled = config['openai_moderation_enabled'] + opts.show_backends = config['show_backends'] if opts.openai_expose_our_model and not opts.openai_api_key: print('If you set openai_epose_our_model to false, you must set your OpenAI 
key in openai_api_key.') diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py index 1b0b5f3..b0db49e 100644 --- a/llm_server/custom_redis.py +++ b/llm_server/custom_redis.py @@ -1,13 +1,13 @@ import pickle import sys import traceback -from typing import Callable, List, Mapping, Union, Optional +from typing import Callable, List, Mapping, Optional, Union import redis as redis_pkg import simplejson as json from flask_caching import Cache from redis import Redis -from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT +from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) @@ -35,12 +35,12 @@ class RedisCustom: def set(self, key, value, ex: Union[ExpiryT, None] = None): return self.redis.set(self._key(key), value, ex=ex) - def get(self, key, dtype=None, default=None): - """ - :param key: - :param dtype: convert to this type - :return: - """ + def get(self, key, default=None, dtype=None): + # TODO: use pickle + import inspect + if inspect.isclass(default): + raise Exception + d = self.redis.get(self._key(key)) if dtype and d: try: @@ -153,11 +153,23 @@ class RedisCustom: keys = [] for key in raw_keys: p = key.decode('utf-8').split(':') - if len(p) > 2: + if len(p) >= 2: + # Delete prefix del p[0] - keys.append(':'.join(p)) + k = ':'.join(p) + if k != '____': + keys.append(k) return keys + def pipeline(self, transaction=True, shard_hint=None): + return self.redis.pipeline(transaction, shard_hint) + + def exists(self, *names: KeyT): + n = [] + for name in names: + n.append(self._key(name)) + return self.redis.exists(*n) + def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): return self.set(key, json.dumps(dict_value), ex=ex) @@ -174,7 +186,7 @@ class RedisCustom: def getp(self, name: str): r = self.redis.get(name) if r: - return pickle.load(r) + return pickle.loads(r) return r def flush(self): diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 3779c83..bf5f537 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -1,60 +1,69 @@ import json import time import traceback +from threading import Thread import llm_server from llm_server import opts +from llm_server.custom_redis import redis from llm_server.database.conn import database from llm_server.llm.vllm import tokenize -from llm_server.custom_redis import redis -def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False): - if isinstance(response, dict) and response.get('results'): - response = response['results'][0]['text'] - try: - j = json.loads(response) - if j.get('results'): - response = j['results'][0]['text'] - except: - pass +def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False): + def background_task(): + nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error + # Try not to shove JSON into the database. 
+ if isinstance(response, dict) and response.get('results'): + response = response['results'][0]['text'] + try: + j = json.loads(response) + if j.get('results'): + response = j['results'][0]['text'] + except: + pass - prompt_tokens = llm_server.llm.get_token_count(prompt) - if not is_error: - if not response_tokens: - response_tokens = llm_server.llm.get_token_count(response) - else: - response_tokens = None + prompt_tokens = llm_server.llm.get_token_count(prompt) + if not is_error: + if not response_tokens: + response_tokens = llm_server.llm.get_token_count(response) + else: + response_tokens = None - # Sometimes we may want to insert null into the DB, but - # usually we want to insert a float. - if gen_time: - gen_time = round(gen_time, 3) - if is_error: - gen_time = None + # Sometimes we may want to insert null into the DB, but + # usually we want to insert a float. + if gen_time: + gen_time = round(gen_time, 3) + if is_error: + gen_time = None - if not opts.log_prompts: - prompt = None + if not opts.log_prompts: + prompt = None - if not opts.log_prompts and not is_error: - # TODO: test and verify this works as expected - response = None + if not opts.log_prompts and not is_error: + # TODO: test and verify this works as expected + response = None - if token: - increment_token_uses(token) + if token: + increment_token_uses(token) - running_model = redis.get('running_model', str, 'ERROR') - timestamp = int(time.time()) - cursor = database.cursor() - try: - cursor.execute(""" - INSERT INTO prompts - (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, - (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) - finally: - cursor.close() + running_model = redis.get('running_model', str, 'ERROR') + timestamp = int(time.time()) + cursor = database.cursor() + try: + cursor.execute(""" + INSERT INTO prompts + (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + (ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) + finally: + cursor.close() + + # TODO: use async/await instead of threads + thread = Thread(target=background_task) + thread.start() + thread.join() def is_valid_api_key(api_key): diff --git a/llm_server/helpers.py b/llm_server/helpers.py index d6eb7d9..9fc7274 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -60,7 +60,7 @@ def round_up_base(n, base): def auto_set_base_client_api(request): - http_host = redis.get('http_host', str) + http_host = redis.get('http_host', dtype=str) host = request.headers.get("Host") if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host): # If the current http_host is not an IP, don't do anything. 
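The rotation helper added above in llm_server/cluster/redis_cycle.py is what get_a_cluster_backend() cycles through when choosing a backend. Below is a minimal sketch of the expected behaviour, assuming Redis on localhost:6379 DB 9 (as hard-coded in that module); the backend URLs are placeholder examples.

    from llm_server.cluster.redis_cycle import load_backend_cycle, redis_cycle

    # Seed the rotation list once; each redis_cycle() call then moves the current
    # head of the list to the tail and returns the new ordering.
    load_backend_cycle('backend_cycler', ['http://10.0.0.1:7000', 'http://10.0.0.2:7000'])
    redis_cycle('backend_cycler')  # ['http://10.0.0.2:7000', 'http://10.0.0.1:7000']
    redis_cycle('backend_cycler')  # ['http://10.0.0.1:7000', 'http://10.0.0.2:7000']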
diff --git a/llm_server/integer.py b/llm_server/integer.py deleted file mode 100644 index 1410dd1..0000000 --- a/llm_server/integer.py +++ /dev/null @@ -1,12 +0,0 @@ -import threading - - -class ThreadSafeInteger: - def __init__(self, value=0): - self.value = value - self._value_lock = threading.Lock() - - def increment(self): - with self._value_lock: - self.value += 1 - return self.value diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index a08b25e..6e39b42 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -3,7 +3,7 @@ from llm_server.custom_redis import redis def get_token_count(prompt: str): - backend_mode = redis.get('backend_mode', str) + backend_mode = redis.get('backend_mode', dtype=str) if backend_mode == 'vllm': return vllm.tokenize(prompt) elif backend_mode == 'ooba': diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index 5dd2093..42c3bb7 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -1,14 +1,14 @@ from llm_server import opts -def generator(request_json_body): +def generator(request_json_body, cluster_backend): if opts.mode == 'oobabooga': # from .oobabooga.generate import generate # return generate(request_json_body) raise NotImplementedError elif opts.mode == 'vllm': from .vllm.generate import generate - r = generate(request_json_body) + r = generate(request_json_body, cluster_backend) return r else: raise Exception diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index bedf3eb..117da3f 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -3,19 +3,15 @@ import requests from llm_server import opts -def get_running_model(backend_url: str): - # TODO: remove this once we go to Redis - if not backend_url: - backend_url = opts.backend_url - - if opts.mode == 'oobabooga': +def get_running_model(backend_url: str, mode: str): + if mode == 'ooba': try: backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['result'], None except Exception as e: return False, e - elif opts.mode == 'vllm': + elif mode == 'vllm': try: backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index 153f66d..e8268b1 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -40,6 +40,6 @@ class LLMBackend: def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]: prompt_len = get_token_count(prompt) if prompt_len > opts.context_size - 10: - model_name = redis.get('running_model', str, 'NO MODEL ERROR') + model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str) return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). 
Please lower your context size' return True, None diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index 8f1898e..62e0ed8 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -34,7 +34,7 @@ def build_openai_response(prompt, response, model=None): # TODO: async/await prompt_tokens = llm_server.llm.get_token_count(prompt) response_tokens = llm_server.llm.get_token_count(response) - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) response = make_response(jsonify({ "id": f"chatcmpl-{generate_oai_string(30)}", @@ -57,7 +57,7 @@ def build_openai_response(prompt, response, model=None): } }), 200) - stats = redis.get('proxy_stats', dict) + stats = redis.get('proxy_stats', dtype=dict) if stats: response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] return response diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 308b1de..caac445 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -49,7 +49,7 @@ def transform_to_text(json_request, api_response): prompt_tokens = len(llm_server.llm.get_token_count(prompt)) completion_tokens = len(llm_server.llm.get_token_count(text)) - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) # https://platform.openai.com/docs/api-reference/making-requests?lang=python return { @@ -82,9 +82,9 @@ def transform_prompt_to_text(prompt: list): return text.strip('\n') -def handle_blocking_request(json_data: dict): +def handle_blocking_request(json_data: dict, cluster_backend): try: - r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) except requests.exceptions.ReadTimeout: print(f'Failed to reach VLLM inference endpoint - request to backend timed out') return False, None, 'Request to backend timed out' @@ -97,11 +97,11 @@ def handle_blocking_request(json_data: dict): return True, r, None -def generate(json_data: dict): +def generate(json_data: dict, cluster_backend): if json_data.get('stream'): try: - return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) except Exception as e: print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') else: - return handle_blocking_request(json_data) + return handle_blocking_request(json_data, cluster_backend) diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index e5b0fad..3db99d9 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -19,16 +19,8 @@ class VLLMBackend(LLMBackend): # Failsafe backend_response = '' - r_url = request.url - - def background_task(): - log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url, - 
response_tokens=response_json_body.get('details', {}).get('generated_tokens')) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task) - thread.start() - thread.join() + log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, + response_tokens=response_json_body.get('details', {}).get('generated_tokens')) return jsonify({'results': [{'text': backend_response}]}), 200 diff --git a/llm_server/opts.py b/llm_server/opts.py index 5eec1fa..0d13979 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -2,7 +2,6 @@ # TODO: rewrite the config system so I don't have to add every single config default here -running_model = 'ERROR' concurrent_gens = 3 mode = 'oobabooga' backend_url = None @@ -38,3 +37,4 @@ openai_org_name = 'OpenAI' openai_silent_trim = False openai_moderation_enabled = True cluster = {} +show_backends = True diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py index 21da08e..900210c 100644 --- a/llm_server/pre_fork.py +++ b/llm_server/pre_fork.py @@ -7,7 +7,7 @@ from llm_server.routes.v1.generate_stats import generate_stats def server_startup(s): - if not redis.get('daemon_started', bool): + if not redis.get('daemon_started', dtype=bool): print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') sys.exit(1) diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py index d97e9c5..a914362 100644 --- a/llm_server/routes/helpers/client.py +++ b/llm_server/routes/helpers/client.py @@ -1,10 +1,14 @@ +from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis -def format_sillytavern_err(msg: str, level: str = 'info'): - http_host = redis.get('http_host', str) +def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'): + cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash'] + http_host = redis.get('http_host', dtype=str) return f"""``` === MESSAGE FROM LLM MIDDLEWARE AT {http_host} === -> {level.upper()} <- {msg} + +BACKEND HASH: {cluster_backend_hash} ```""" diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index d6b02e2..8e0036c 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler): msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' 
backend_response = self.handle_error(msg) if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True) return backend_response[0], 200 # We only return the response from handle_error(), not the error code def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: @@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler): # TODO: how to format this response_msg = error_msg else: - response_msg = format_sillytavern_err(error_msg, error_type) + response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend) return jsonify({ 'results': [{'text': response_msg}] diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index a289c78..b3159a5 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -5,11 +5,12 @@ import traceback from flask import Response, jsonify, request -from . import openai_bp from llm_server.custom_redis import redis +from . import openai_bp from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler from ... import opts +from ...cluster.backend import get_a_cluster_backend from ...database.database import log_prompt from ...llm.generator import generator from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt @@ -48,10 +49,11 @@ def openai_chat_completions(): 'stream': True, } try: - response = generator(msg_to_backend) + cluster_backend = get_a_cluster_backend() + response = generator(msg_to_backend, cluster_backend) r_headers = dict(request.headers) r_url = request.url - model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model') + model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') oai_string = generate_oai_string(30) def generate(): @@ -94,7 +96,7 @@ def openai_chat_completions(): def background_task(): generated_tokens = tokenize(generated_text) - log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens) + log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens) # TODO: use async/await instead of threads thread = threading.Thread(target=background_task) diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 05ac7a5..8950927 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -29,7 +29,7 @@ def openai_completions(): # TODO: async/await prompt_tokens = get_token_count(request_json_body['prompt']) response_tokens = get_token_count(output) - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) response = make_response(jsonify({ "id": f"cmpl-{generate_oai_string(30)}", @@ -51,7 +51,7 @@ def 
openai_completions():
             }
         }), 200)
 
-        stats = redis.get('proxy_stats', dict)
+        stats = redis.get('proxy_stats', dtype=dict)
         if stats:
             response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
         return response
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index 4f732e6..657f084 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -7,6 +7,7 @@ from . import openai_bp
 from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
 from ..stats import server_start_time
 from ... import opts
+from ...cluster.backend import get_a_cluster_backend
 from ...helpers import jsonify_pretty
 from ...llm.info import get_running_model
 
@@ -22,7 +23,7 @@ def openai_list_models():
             'type': error.__class__.__name__
         }), 500  # return 500 so Cloudflare doesn't intercept us
     else:
-        running_model = redis.get('running_model', str, 'ERROR')
+        running_model = redis.get('running_model', 'ERROR', dtype=str)
         oai = fetch_openai_models()
         r = []
         if opts.openai_expose_our_model:
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index 8d85319..09ed06c 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -93,6 +93,6 @@ def incr_active_workers():
 
 def decr_active_workers():
     redis.decr('active_gen_workers')
-    new_count = redis.get('active_gen_workers', int, 0)
+    new_count = redis.get('active_gen_workers', 0, dtype=int)
     if new_count < 0:
         redis.set('active_gen_workers', 0)
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index bb64859..ecae085 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -5,13 +5,15 @@ import flask
 from flask import Response, request
 
 from llm_server import opts
+from llm_server.cluster.backend import get_a_cluster_backend
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis
 from llm_server.database.conn import database
 from llm_server.database.database import log_prompt
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
-from llm_server.custom_redis import redis
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
@@ -35,7 +37,9 @@ class RequestHandler:
         self.client_ip = self.get_client_ip()
         self.token = self.get_auth_token()
         self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
-        self.backend = get_backend()
+        self.cluster_backend = get_a_cluster_backend()
+        self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
+        self.backend = get_backend_handler(self.cluster_backend_info['mode'])
         self.parameters = None
         self.used = False
         redis.zadd('recent_prompters', {self.client_ip: time.time()})
@@ -119,7 +123,7 @@ class RequestHandler:
             backend_response = self.handle_error(combined_error_message, 'Validation Error')
 
             if do_log:
-                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+                log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
return False, backend_response return True, (None, 0) @@ -131,7 +135,7 @@ class RequestHandler: request_valid, invalid_response = self.validate_request(prompt, do_log=True) if not request_valid: return (False, None, None, 0), invalid_response - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority) + event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority) else: event = None @@ -160,7 +164,7 @@ class RequestHandler: else: error_msg = error_msg.strip('.') + '.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -180,7 +184,7 @@ class RequestHandler: if return_json_err: error_msg = 'The backend did not return valid JSON.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -214,10 +218,10 @@ class RequestHandler: raise NotImplementedError -def get_backend(): - if opts.mode == 'oobabooga': +def get_backend_handler(mode): + if mode == 'oobabooga': return OobaboogaBackend() - elif opts.mode == 'vllm': + elif mode == 'vllm': return VLLMBackend() else: raise Exception diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index a0846d9..b4dea54 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -2,32 +2,9 @@ from datetime import datetime from llm_server.custom_redis import redis -# proompters_5_min = 0 -# concurrent_semaphore = Semaphore(concurrent_gens) - server_start_time = datetime.now() -# TODO: do I need this? 
-# def elapsed_times_cleanup(): -# global wait_in_queue_elapsed -# while True: -# current_time = time.time() -# with wait_in_queue_elapsed_lock: -# global wait_in_queue_elapsed -# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60] -# time.sleep(1) - - -def calculate_avg_gen_time(): - # Get the average generation time from Redis - average_generation_time = redis.get('average_generation_time') - if average_generation_time is None: - return 0 - else: - return float(average_generation_time) - - def get_total_proompts(): count = redis.get('proompts') if count is None: diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index b2dd527..66dd316 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -2,11 +2,12 @@ import time from datetime import datetime from llm_server import opts +from llm_server.cluster.backend import get_a_cluster_backend, test_backend +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis from llm_server.database.database import get_distinct_ips_24h, sum_column from llm_server.helpers import deep_sort, round_up_base from llm_server.llm.info import get_running_model -from llm_server.netdata import get_power_states -from llm_server.custom_redis import redis from llm_server.routes.queue import priority_queue from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time @@ -33,52 +34,43 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act return gen_time_calc -# TODO: have routes/__init__.py point to the latest API version generate_stats() - def generate_stats(regen: bool = False): if not regen: - c = redis.get('proxy_stats', dict) + c = redis.get('proxy_stats', dtype=dict) if c: return c - model_name, error = get_running_model() # will return False when the fetch fails - if isinstance(model_name, bool): - online = False - else: - online = True - redis.set('running_model', model_name) + default_backend_url = get_a_cluster_backend() + default_backend_info = cluster_config.get_backend(default_backend_url) + if not default_backend_info.get('mode'): + # TODO: remove + print('DAEMON NOT FINISHED STARTING') + return + base_client_api = redis.get('base_client_api', dtype=str) + proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf')) + average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0) - # t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change - # if len(t) == 0: - # estimated_wait = 0 - # else: - # waits = [elapsed for end, elapsed in t] - # estimated_wait = int(sum(waits) / len(waits)) + online = test_backend(default_backend_url, default_backend_info['mode']) + if online: + running_model, err = get_running_model(default_backend_url, default_backend_info['mode']) + cluster_config.set_backend_value(default_backend_url, 'running_model', running_model) + else: + running_model = None active_gen_workers = get_active_gen_workers() proompters_in_queue = len(priority_queue) - # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM. + # This is so wildly inaccurate it's disabled. 
     # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
-    average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
-    estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
-
-    if opts.netdata_root:
-        netdata_stats = {}
-        power_states = get_power_states()
-        for gpu, power_state in power_states.items():
-            netdata_stats[gpu] = {
-                'power_state': power_state,
-                # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
-            }
-    else:
-        netdata_stats = {}
-
-    base_client_api = redis.get('base_client_api', str)
-    proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
+    # TODO: make this for the currently selected backend
+    estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
 
     output = {
+        'default': {
+            'model': running_model,
+            'backend': default_backend_info['hash'],
+        },
         'stats': {
             'proompters': {
                 '5_min': proompters_5_min,
@@ -86,9 +78,10 @@
             },
             'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
             'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
-            'average_generation_elapsed_sec': int(average_generation_time),
+            'average_generation_elapsed_sec': int(average_generation_elapsed_sec),
             # 'estimated_avg_tps': estimated_avg_tps,
             'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
+            'num_backends': len(cluster_config.all()) if opts.show_backends else None,
         },
         'online': online,
         'endpoints': {
@@ -103,10 +96,7 @@
         'timestamp': int(time.time()),
         'config': {
             'gatekeeper': 'none' if opts.auth_required is False else 'token',
-            'context_size': opts.context_size,
             'concurrent': opts.concurrent_gens,
-            'model': opts.manual_model_name if opts.manual_model_name else model_name,
-            'mode': opts.mode,
             'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
         },
         'keys': {
@@ -114,8 +104,41 @@
             'anthropicKeys': '∞',
         },
         'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
-        'nvidia': netdata_stats
     }
+
+    if opts.show_backends:
+        for backend_url, v in cluster_config.all().items():
+            backend_info = cluster_config.get_backend(backend_url)
+            if not backend_info['online']:
+                continue
+
+            # TODO: have this fetch the data from VLLM which will display GPU utilization
+            # if opts.netdata_root:
+            #     netdata_stats = {}
+            #     power_states = get_power_states()
+            #     for gpu, power_state in power_states.items():
+            #         netdata_stats[gpu] = {
+            #             'power_state': power_state,
+            #             # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
+            #         }
+            # else:
+            #     netdata_stats = {}
+            netdata_stats = {}
+
+            # TODO: use value returned by VLLM backend here
+            # backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None
+            backend_uptime = -1
+
+            output['backend_info'][backend_info['hash']] = {
+                'uptime': backend_uptime,
+                # 'context_size': opts.context_size,
+                'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'),
+                'mode': backend_info['mode'],
+                'nvidia': netdata_stats
+            }
+    else:
+        output['backend_info'] = {}
+
     result = deep_sort(output)
 
     # It may take a bit to get the base client API, so don't cache until then.
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 45fbf12..e3aeeb0 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -1,5 +1,4 @@ import json -import threading import time import traceback from typing import Union @@ -10,10 +9,11 @@ from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... import opts +from ...cluster.backend import get_a_cluster_backend from ...database.database import log_prompt from ...llm.generator import generator from ...llm.vllm import tokenize -from ...stream import sock +from ...sock import sock # TODO: have workers process streaming requests @@ -35,19 +35,13 @@ def stream(ws): log_in_bg(quitting_err_msg, is_error=True) def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None): - - def background_task_exception(): - generated_tokens = tokenize(generated_text_bg) - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task_exception) - thread.start() - thread.join() + generated_tokens = tokenize(generated_text_bg) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error) if not opts.enable_streaming: return 'Streaming is disabled', 401 + cluster_backend = None r_headers = dict(request.headers) r_url = request.url message_num = 0 @@ -90,14 +84,15 @@ def stream(ws): } # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority) + event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority) if not event: r, _ = handler.handle_ratelimited() err_msg = r.json['results'][0]['text'] send_err_and_quit(err_msg) return try: - response = generator(llm_request) + cluster_backend = get_a_cluster_backend() + response = generator(llm_request, cluster_backend) if not response: error_msg = 'Failed to reach backend while streaming.' print('Streaming failed:', error_msg) @@ -142,7 +137,7 @@ def stream(ws): ws.close() end_time = time.time() elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text)) return message_num += 1 @@ -181,5 +176,5 @@ def stream(ws): # The client closed the stream. 
        end_time = time.time()
        elapsed_time = end_time - start_time
-        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
+        log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
     ws.close()  # this is important if we encountered and error and exited early.
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 7cdbf0f..90778e5 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -2,22 +2,21 @@ import time
 
 from flask import jsonify, request
 
+from llm_server.custom_redis import flask_cache
 from . import bp
 from ..auth import requires_auth
-from llm_server.custom_redis import flask_cache
 from ... import opts
-from ...llm.info import get_running_model
-
-
-# @bp.route('/info', methods=['GET'])
-# # @cache.cached(timeout=3600, query_string=True)
-# def get_info():
-#     # requests.get()
-#     return 'yes'
+from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
+from ...cluster.cluster_config import cluster_config
 
 
 @bp.route('/model', methods=['GET'])
-def get_model():
+@bp.route('/<model_name>/model', methods=['GET'])
+def get_model(model_name=None):
+    if not model_name:
+        b = get_a_cluster_backend()
+        model_name = cluster_config.get_backend(b)['running_model']
+
     # We will manage caching ourself since we don't want to cache
     # when the backend is down. Also, Cloudflare won't cache 500 errors.
     cache_key = 'model_cache::' + request.url
@@ -26,16 +25,17 @@ def get_model():
     if cached_response:
         return cached_response
 
-    model_name, error = get_running_model()
-    if not model_name:
+    if not is_valid_model(model_name):
         response = jsonify({
-            'code': 502,
-            'msg': 'failed to reach backend',
-            'type': error.__class__.__name__
-        }), 500  # return 500 so Cloudflare doesn't intercept us
+            'code': 400,
+            'msg': 'Model does not exist.',
+        }), 400
     else:
+        num_backends = len(get_backends_from_model(model_name))
+
         response = jsonify({
             'result': opts.manual_model_name if opts.manual_model_name else model_name,
+            'model_backend_count': num_backends,
             'timestamp': int(time.time())
         }), 200
         flask_cache.set(cache_key, response, timeout=60)
@@ -43,7 +43,11 @@ def get_model():
 
     return response
 
-@bp.route('/backend', methods=['GET'])
+@bp.route('/backends', methods=['GET'])
 @requires_auth
 def get_backend():
-    return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
+    online, offline = get_backends()
+    result = []
+    for i in online + offline:
+        result.append(cluster_config.get_backend(i))
+    return jsonify(result), 200
diff --git a/llm_server/stream.py b/llm_server/sock.py
similarity index 100%
rename from llm_server/stream.py
rename to llm_server/sock.py
diff --git a/llm_server/workers/app.py b/llm_server/workers/app.py
deleted file mode 100644
index fda6fb3..0000000
--- a/llm_server/workers/app.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from threading import Thread
-
-from .blocking import start_workers
-from .main import main_background_thread
-from .moderator import start_moderation_workers
-from .printer import console_printer
-from .recent import recent_prompters_thread
-from .threads import cache_stats
-from ..
import opts - - -def start_background(): - start_workers(opts.concurrent_gens) - - t = Thread(target=main_background_thread) - t.daemon = True - t.start() - print('Started the main background thread.') - - start_moderation_workers(opts.openai_moderation_workers) - - t = Thread(target=cache_stats) - t.daemon = True - t.start() - print('Started the stats cacher.') - - t = Thread(target=recent_prompters_thread) - t.daemon = True - t.start() - print('Started the recent proompters thread.') - - t = Thread(target=console_printer) - t.daemon = True - t.start() - print('Started the console printer.') diff --git a/llm_server/workers/blocking.py b/llm_server/workers/inferencer.py similarity index 88% rename from llm_server/workers/blocking.py rename to llm_server/workers/inferencer.py index dcf0047..626e34b 100644 --- a/llm_server/workers/blocking.py +++ b/llm_server/workers/inferencer.py @@ -2,15 +2,15 @@ import threading import time from llm_server import opts -from llm_server.llm.generator import generator from llm_server.custom_redis import redis +from llm_server.llm.generator import generator from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue def worker(): while True: need_to_wait() - (request_json_body, client_ip, token, parameters), event_id = priority_queue.get() + (request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get() need_to_wait() increment_ip_count(client_ip, 'processing_ips') @@ -22,7 +22,7 @@ def worker(): continue try: - success, response, error_msg = generator(request_json_body) + success, response, error_msg = generator(request_json_body, cluster_backend) event = DataEvent(event_id) event.set((success, response, error_msg)) finally: @@ -42,7 +42,7 @@ def start_workers(num_workers: int): def need_to_wait(): # We need to check the number of active workers since the streaming endpoint may be doing something. - active_workers = redis.get('active_gen_workers', int, 0) + active_workers = redis.get('active_gen_workers', 0, dtype=int) s = time.time() while active_workers >= opts.concurrent_gens: time.sleep(0.01) diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py deleted file mode 100644 index f592c5e..0000000 --- a/llm_server/workers/main.py +++ /dev/null @@ -1,55 +0,0 @@ -import time - -from llm_server import opts -from llm_server.database.database import weighted_average_column_for_model -from llm_server.llm.info import get_running_model -from llm_server.custom_redis import redis - - -def main_background_thread(): - redis.set('average_generation_elapsed_sec', 0) - redis.set('estimated_avg_tps', 0) - redis.set('average_output_tokens', 0) - redis.set('backend_online', 0) - redis.set_dict('backend_info', {}) - - while True: - # TODO: unify this - if opts.mode == 'oobabooga': - running_model, err = get_running_model() - if err: - print(err) - redis.set('backend_online', 0) - else: - redis.set('running_model', running_model) - redis.set('backend_online', 1) - elif opts.mode == 'vllm': - running_model, err = get_running_model() - if err: - print(err) - redis.set('backend_online', 0) - else: - redis.set('running_model', running_model) - redis.set('backend_online', 1) - else: - raise Exception - - # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 - # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. 
-        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
-        if average_generation_elapsed_sec:  # returns None on exception
-            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
-
-        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
-        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
-
-        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
-        if average_generation_elapsed_sec:
-            redis.set('average_output_tokens', average_output_tokens)
-
-        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
-        # print(f'Weighted: {average_output_tokens}, overall: {overall}')
-
-        estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
-        redis.set('estimated_avg_tps', estimated_avg_tps)
-        time.sleep(60)
diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py
new file mode 100644
index 0000000..447046f
--- /dev/null
+++ b/llm_server/workers/mainer.py
@@ -0,0 +1,56 @@
+import time
+
+from llm_server import opts
+from llm_server.cluster.backend import get_a_cluster_backend, get_backends
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis
+from llm_server.database.database import weighted_average_column_for_model
+from llm_server.llm.info import get_running_model
+
+
+def main_background_thread():
+    while True:
+        online, offline = get_backends()
+        for backend_url in online:
+            backend_info = cluster_config.get_backend(backend_url)
+            backend_mode = backend_info['mode']
+            running_model, err = get_running_model(backend_url, backend_mode)
+            if err:
+                continue
+
+            average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
+            if average_generation_elapsed_sec:  # returns None on exception
+                cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
+            if average_output_tokens:
+                cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
+            if average_generation_elapsed_sec and average_output_tokens:
+                cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
+
+        default_backend_url = get_a_cluster_backend()
+        default_backend_info = cluster_config.get_backend(default_backend_url)
+        default_backend_mode = default_backend_info['mode']
+        default_running_model, err = get_running_model(default_backend_url, default_backend_mode)
+        if err:
+            continue
+
+        default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_backend_url, default_running_model, default_backend_mode)
+        if default_average_generation_elapsed_sec:
+            redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec)
+        if default_average_output_tokens:
+            redis.set('average_output_tokens', default_average_output_tokens)
+        if default_average_generation_elapsed_sec and default_average_output_tokens:
+            redis.set('estimated_avg_tps', default_estimated_avg_tps)
+        time.sleep(30)
+
+def calc_stats_for_backend(backend_url, running_model, backend_mode): + # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 + # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. + average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', + running_model, backend_mode, backend_url, exclude_zeros=True, + include_system_tokens=opts.include_system_tokens_in_stats) or 0 + average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', + running_model, backend_mode, backend_url, exclude_zeros=True, + include_system_tokens=opts.include_system_tokens_in_stats) or 0 + estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero + return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps diff --git a/llm_server/workers/recent.py b/llm_server/workers/recenter.py similarity index 100% rename from llm_server/workers/recent.py rename to llm_server/workers/recenter.py diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py new file mode 100644 index 0000000..83bac2d --- /dev/null +++ b/llm_server/workers/threader.py @@ -0,0 +1,50 @@ +import time +from threading import Thread + +from llm_server import opts +from llm_server.cluster.stores import redis_running_models +from llm_server.cluster.worker import cluster_worker +from llm_server.routes.v1.generate_stats import generate_stats +from llm_server.workers.inferencer import start_workers +from llm_server.workers.mainer import main_background_thread +from llm_server.workers.moderator import start_moderation_workers +from llm_server.workers.printer import console_printer +from llm_server.workers.recenter import recent_prompters_thread + + +def cache_stats(): + while True: + generate_stats(regen=True) + time.sleep(1) + + +def start_background(): + start_workers(opts.concurrent_gens) + + t = Thread(target=main_background_thread) + t.daemon = True + t.start() + print('Started the main background thread.') + + start_moderation_workers(opts.openai_moderation_workers) + + t = Thread(target=cache_stats) + t.daemon = True + t.start() + print('Started the stats cacher.') + + t = Thread(target=recent_prompters_thread) + t.daemon = True + t.start() + print('Started the recent proompters thread.') + + t = Thread(target=console_printer) + t.daemon = True + t.start() + print('Started the console printer.') + + redis_running_models.flush() + t = Thread(target=cluster_worker) + t.daemon = True + t.start() + print('Started the cluster worker.') diff --git a/llm_server/workers/threads.py b/llm_server/workers/threads.py deleted file mode 100644 index d1c5183..0000000 --- a/llm_server/workers/threads.py +++ /dev/null @@ -1,9 +0,0 @@ -import time - -from llm_server.routes.v1.generate_stats import generate_stats - - -def cache_stats(): - while True: - generate_stats(regen=True) - time.sleep(5) diff --git a/gunicorn.py b/other/gunicorn.py similarity index 60% rename from gunicorn.py rename to other/gunicorn.py index 30f9274..099e9ce 100644 --- a/gunicorn.py +++ b/other/gunicorn.py @@ -1,3 +1,8 @@ +""" +This file is used to run certain tasks when the HTTP server starts. 
+It's located here so it doesn't get imported with daemon.py +""" + try: import gevent.monkey diff --git a/server.py b/server.py index 3c334bc..0214b49 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,4 @@ -from llm_server.config.config import mode_ui_names +from llm_server.cluster.cluster_config import cluster_config try: import gevent.monkey @@ -7,8 +7,6 @@ try: except ImportError: pass -from llm_server.pre_fork import server_startup -from llm_server.config.load import load_config, parse_backends import os import sys from pathlib import Path @@ -16,14 +14,17 @@ from pathlib import Path import simplejson as json from flask import Flask, jsonify, render_template, request -import llm_server +from llm_server.cluster.backend import get_a_cluster_backend, get_backends +from llm_server.cluster.redis_cycle import load_backend_cycle +from llm_server.config.config import mode_ui_names +from llm_server.config.load import load_config, parse_backends from llm_server.database.conn import database from llm_server.database.create import create_db -from llm_server.llm import get_token_count +from llm_server.pre_fork import server_startup from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp -from llm_server.stream import init_socketio +from llm_server.sock import init_socketio # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail @@ -37,6 +38,8 @@ from llm_server.stream import init_socketio # TODO: use coloredlogs # TODO: need to update opts. for workers # TODO: add a healthcheck to VLLM +# TODO: allow choosing the model by the URL path +# TODO: have VLLM report context size, uptime # Lower priority # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens @@ -64,7 +67,7 @@ import config from llm_server import opts from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info -from llm_server.custom_redis import RedisCustom, flask_cache +from llm_server.custom_redis import flask_cache from llm_server.llm import redis from llm_server.routes.stats import get_active_gen_workers from llm_server.routes.v1.generate_stats import generate_stats @@ -83,20 +86,18 @@ if config_path_environ: else: config_path = Path(script_path, 'config', 'config.yml') -success, config, msg = load_config(config_path, script_path) +success, config, msg = load_config(config_path) if not success: print('Failed to load config:', msg) sys.exit(1) database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) create_db() -llm_server.llm.redis = RedisCustom('local_llm') -create_db() -x = parse_backends(config) -print(x) - -# print(app.url_map) +cluster_config.clear() +cluster_config.load(parse_backends(config)) +on, off = get_backends() +load_backend_cycle('backend_cycler', on + off) @app.route('/') @@ -104,12 +105,18 @@ print(x) @app.route('/api/openai') @flask_cache.cached(timeout=10) def home(): - stats = generate_stats() + # Use the default backend + backend_url = get_a_cluster_backend() + if backend_url: + backend_info = cluster_config.get_backend(backend_url) + stats = generate_stats(backend_url) + else: + backend_info = stats = None if not stats['online']: running_model = estimated_wait_sec = 'offline' else: - 
running_model = redis.get('running_model', str, 'ERROR')
+        running_model = backend_info['running_model']
 
     active_gen_workers = get_active_gen_workers()
 
     if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
@@ -130,10 +137,16 @@ def home():
         info_html = ''
 
     mode_info = ''
-    if opts.mode == 'vllm':
+    using_vllm = False
+    for k, v in cluster_config.all().items():
+        if v['mode'] == 'vllm':
+            using_vllm = True
+            break
+
+    if using_vllm:
         mode_info = vllm_info
 
-    base_client_api = redis.get('base_client_api', str)
+    base_client_api = redis.get('base_client_api', dtype=str)
 
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,
diff --git a/test-cluster.py b/test-cluster.py
index 531892b..ec1773a 100644
--- a/test-cluster.py
+++ b/test-cluster.py
@@ -7,23 +7,33 @@ except ImportError:
 
 import time
 from threading import Thread
+from llm_server.cluster.redis_cycle import load_backend_cycle
 
-from llm_server.cluster.funcs.backend import get_best_backends
-from llm_server.cluster.redis_config_cache import RedisClusterStore
+from llm_server.cluster.backend import get_backends, get_a_cluster_backend
 from llm_server.cluster.worker import cluster_worker
 from llm_server.config.load import parse_backends, load_config
+from llm_server.cluster.redis_config_cache import RedisClusterStore
 
-success, config, msg = load_config('./config/config.yml').resolve().absolute()
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('config')
+args = parser.parse_args()
+
+success, config, msg = load_config(args.config)
 
 cluster_config = RedisClusterStore('cluster_config')
 cluster_config.clear()
 cluster_config.load(parse_backends(config))
+on, off = get_backends()
+load_backend_cycle('backend_cycler', on + off)
 
 t = Thread(target=cluster_worker)
 t.daemon = True
 t.start()
 
 while True:
-    x = get_best_backends()
-    print(x)
+    # online, offline = get_backends()
+    # print(online, offline)
+    # print(get_a_cluster_backend())
     time.sleep(3)
-- 
2.34.1


From 114f36e70984f00b680b41f79662fd6909e2c68b Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sat, 30 Sep 2023 19:41:50 -0600
Subject: [PATCH 003/163] functional

---
 daemon.py | 25 +++++-
 llm_server/cluster/backend.py | 78 ++++++++++++-------
 llm_server/cluster/model_choices.py | 88 +++++++++++++++++++++
 llm_server/cluster/redis_config_cache.py | 3 +
 llm_server/cluster/redis_cycle.py | 42 ++++++----
 llm_server/cluster/worker.py | 37 +++++----
 llm_server/config/config.py | 3 +-
 llm_server/config/load.py | 2 +-
 llm_server/custom_redis.py | 31 +++++++-
 llm_server/database/database.py | 14 ++--
 llm_server/llm/__init__.py | 4 +-
 llm_server/llm/generator.py | 5 +-
 llm_server/llm/info.py | 15 ++++
 llm_server/llm/llm_backend.py | 13 +++-
 llm_server/llm/oobabooga/ooba_backend.py | 9 ++-
 llm_server/llm/vllm/generate.py | 69 +++------------
 llm_server/llm/vllm/info.py | 6 +-
 llm_server/llm/vllm/tokenize.py | 8 +-
 llm_server/llm/vllm/vllm_backend.py | 2 +-
 llm_server/opts.py | 2 +-
 llm_server/pre_fork.py | 12 ---
 llm_server/routes/helpers/client.py | 11 +--
 llm_server/routes/ooba_request_handler.py | 4 +-
 llm_server/routes/queue.py | 39 +++++++--
 llm_server/routes/request_handler.py | 27 ++++---
 llm_server/routes/stats.py | 27 ++++++-
 llm_server/routes/v1/generate.py | 10 ++-
 llm_server/routes/v1/generate_stats.py | 95 ++++------------------
 llm_server/routes/v1/info.py | 17 ++--
 llm_server/workers/inferencer.py | 25 +++---
 llm_server/workers/mainer.py | 22 +----
 llm_server/workers/printer.py | 11 ++-
llm_server/workers/threader.py | 4 +- other/vllm/vllm_api_server.py | 0 requirements.txt | 11 +-- server.py | 75 +++++++----------- templates/home.html | 35 ++++++++- test-cluster.py | 39 ---------- 38 files changed, 505 insertions(+), 415 deletions(-) create mode 100644 llm_server/cluster/model_choices.py mode change 100755 => 100644 other/vllm/vllm_api_server.py delete mode 100644 test-cluster.py diff --git a/daemon.py b/daemon.py index 82635f0..0fa3601 100644 --- a/daemon.py +++ b/daemon.py @@ -3,9 +3,14 @@ import sys import time from pathlib import Path -from llm_server.config.load import load_config +from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.redis_cycle import redis_cycler_db +from llm_server.cluster.stores import redis_running_models +from llm_server.config.load import load_config, parse_backends from llm_server.custom_redis import redis from llm_server.database.create import create_db +from llm_server.routes.queue import priority_queue +from llm_server.routes.v1.generate_stats import generate_stats from llm_server.workers.threader import start_background script_path = os.path.dirname(os.path.realpath(__file__)) @@ -19,16 +24,30 @@ if __name__ == "__main__": flushed_keys = redis.flush() print('Flushed', len(flushed_keys), 'keys from Redis.') + redis_cycler_db.flushall() + redis_running_models.flush() + success, config, msg = load_config(config_path) if not success: print('Failed to load config:', msg) sys.exit(1) create_db() + + priority_queue.flush() + cluster_config.clear() + cluster_config.load(parse_backends(config)) + + print('Loading backend stats...') + generate_stats() + start_background() redis.set('daemon_started', 1) print('== Daemon Setup Complete ==\n') - while True: - time.sleep(3600) + try: + while True: + time.sleep(3600) + except KeyboardInterrupt: + redis.set('daemon_started', 0) diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py index 7b28e86..bb3e6d4 100644 --- a/llm_server/cluster/backend.py +++ b/llm_server/cluster/backend.py @@ -1,23 +1,34 @@ -from llm_server.cluster.redis_config_cache import RedisClusterStore -from llm_server.cluster.redis_cycle import redis_cycle +from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle from llm_server.cluster.stores import redis_running_models -from llm_server.llm.info import get_running_model +from llm_server.llm.generator import generator +from llm_server.llm.info import get_info -def test_backend(backend_url: str, mode: str): - running_model, err = get_running_model(backend_url, mode) - if not running_model: - return False - return True +def test_backend(backend_url: str, test_prompt: bool = False): + backend_info = cluster_config.get_backend(backend_url) + if test_prompt: + data = { + "prompt": "Test prompt", + "stream": False, + "temperature": 0, + "max_new_tokens": 16, + } + success, response, err = generator(data, backend_url, timeout=10) + if not success or not response or err: + return False, {} + i = get_info(backend_url, backend_info['mode']) + if not i.get('model'): + return False, {} + return True, i def get_backends(): - cluster_config = RedisClusterStore('cluster_config') backends = cluster_config.all() result = {} for k, v in backends.items(): b = cluster_config.get_backend(k) - status = b['online'] + status = b.get('online', False) priority = b['priority'] result[k] = {'status': status, 'priority': priority} online_backends = sorted( @@ -33,30 +44,43 @@ def 
get_backends(): return [url for url, info in online_backends], [url for url, info in offline_backends] -def get_a_cluster_backend(): +def get_a_cluster_backend(model=None): """ Get a backend from Redis. If there are no online backends, return None. + If `model` is not supplied, we will pick one ourself. """ - online, offline = get_backends() - cycled = redis_cycle('backend_cycler') - c = cycled.copy() - for i in range(len(cycled)): - if cycled[i] in offline: - del c[c.index(cycled[i])] - if len(c): - return c[0] + if model: + # First, determine if there are multiple backends hosting the same model. + backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)] + + # If so, create an iterator for those backends + if len(backends_hosting_model): + add_backend_cycler(model, backends_hosting_model) + cycled = redis_cycle(model) + if len(cycled): + return cycled[0] + else: + # No backend hosting that model + return None else: - return None + online, _ = get_backends() + if len(online): + return online[0] def get_backends_from_model(model_name: str): - cluster_config = RedisClusterStore('cluster_config') - a = cluster_config.all() - matches = [] - for k, v in a.items(): - if v['online'] and v['running_model'] == model_name: - matches.append(k) - return matches + return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)] + + +# def verify_context_size(model_name:str): +# b = get_backends_from_model(model_name) +# for backend_url in b: +# backend_info = cluster_config.get_backend(backend_url) +# backend_info.get() + + +def get_running_models(): + return redis_running_models.keys() def purge_backend_from_running_models(backend_url: str): diff --git a/llm_server/cluster/model_choices.py b/llm_server/cluster/model_choices.py new file mode 100644 index 0000000..c9a94fd --- /dev/null +++ b/llm_server/cluster/model_choices.py @@ -0,0 +1,88 @@ +import numpy as np + +from llm_server import opts +from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_model, get_running_models +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis +from llm_server.routes.queue import priority_queue +from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers + + +# TODO: give this a better name! 
+def get_model_choices(regen: bool = False): + if not regen: + c = redis.getp('model_choices') + if c: + return c + + base_client_api = redis.get('base_client_api', dtype=str) + running_models = get_running_models() + model_choices = {} + for model in running_models: + b = get_backends_from_model(model) + + context_size = [] + avg_gen_per_worker = [] + for backend_url in b: + backend_info = cluster_config.get_backend(backend_url) + if backend_info.get('model_config'): + context_size.append(backend_info['model_config']['max_position_embeddings']) + if backend_info.get('average_generation_elapsed_sec'): + avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec']) + + active_gen_workers = get_active_gen_workers(model) + proompters_in_queue = priority_queue.len(model) + + if len(avg_gen_per_worker): + average_generation_elapsed_sec = np.average(avg_gen_per_worker) + else: + average_generation_elapsed_sec = 0 + estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers) + + if proompters_in_queue == 0 and active_gen_workers >= opts.concurrent_gens: + # There will be a wait if the queue is empty but prompts are processing, but we don't + # know how long. + estimated_wait_sec = f"less than {estimated_wait_sec} seconds" + else: + estimated_wait_sec = f"{estimated_wait_sec} seconds" + + model_choices[model] = { + 'client_api': f'https://{base_client_api}/v2/{model}', + 'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None, + 'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled', + 'backend_count': len(b), + 'estimated_wait': estimated_wait_sec, + 'queued': proompters_in_queue, + 'processing': active_gen_workers, + 'avg_generation_time': average_generation_elapsed_sec + } + + if len(context_size): + model_choices[model]['context_size'] = min(context_size) + + model_choices = dict(sorted(model_choices.items())) + + default_backend = get_a_cluster_backend() + default_backend_info = cluster_config.get_backend(default_backend) + default_context_size = default_backend_info['model_config']['max_position_embeddings'] + default_average_generation_elapsed_sec = default_backend_info.get('average_generation_elapsed_sec') + default_active_gen_workers = redis.get(f'active_gen_workers:{default_backend}', dtype=int, default=0) + default_proompters_in_queue = priority_queue.len(default_backend_info['model']) + default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers) + + default_backend_dict = { + 'client_api': f'https://{base_client_api}/v2', + 'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None, + 'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled', + 'estimated_wait': default_estimated_wait_sec, + 'queued': default_proompters_in_queue, + 'processing': default_active_gen_workers, + 'context_size': default_context_size, + 'hash': default_backend_info['hash'], + 'model': default_backend_info['model'], + 'avg_generation_time': default_average_generation_elapsed_sec + } + + redis.setp('model_choices', (model_choices, default_backend_dict)) + + return model_choices, default_backend_dict diff --git a/llm_server/cluster/redis_config_cache.py b/llm_server/cluster/redis_config_cache.py index 
ebb6099..3bab915 100644 --- a/llm_server/cluster/redis_config_cache.py +++ b/llm_server/cluster/redis_config_cache.py @@ -44,3 +44,6 @@ class RedisClusterStore: return result else: return {} + + # def get(self, name: str): + # return self.all().get(name) diff --git a/llm_server/cluster/redis_cycle.py b/llm_server/cluster/redis_cycle.py index 87893ba..7cff2c4 100644 --- a/llm_server/cluster/redis_cycle.py +++ b/llm_server/cluster/redis_cycle.py @@ -1,21 +1,35 @@ import redis -r = redis.Redis(host='localhost', port=6379, db=9) +redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9) def redis_cycle(list_name): - while True: - pipe = r.pipeline() - pipe.lpop(list_name) - popped_element = pipe.execute()[0] - if popped_element is None: - return None - r.rpush(list_name, popped_element) - new_list = r.lrange(list_name, 0, -1) - return [x.decode('utf-8') for x in new_list] + """ + Emulates itertools.cycle() but returns the complete shuffled list. + :param list_name: + :return: + """ + to_move = redis_cycler_db.rpop(list_name) + if not to_move: + return [] + redis_cycler_db.lpush(list_name, to_move) + new_list = redis_cycler_db.lrange(list_name, 0, -1) + return [x.decode('utf-8') for x in new_list] -def load_backend_cycle(list_name: str, elements: list): - r.delete(list_name) - for element in elements: - r.rpush(list_name, element) +def add_backend_cycler(list_name: str, new_elements: list): + existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)] + existing_set = set(existing_elements) + + with redis_cycler_db.pipeline() as pipe: + # Add elements + for element in new_elements: + if element not in existing_set: + pipe.rpush(list_name, element) + + # Remove elements + for element in existing_set: + if element not in new_elements: + pipe.lrem(list_name, 0, element) + + pipe.execute() diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py index bee280a..7956198 100644 --- a/llm_server/cluster/worker.py +++ b/llm_server/cluster/worker.py @@ -1,31 +1,42 @@ -from datetime import datetime +import time from threading import Thread -from llm_server.cluster.backend import purge_backend_from_running_models, test_backend +from llm_server.cluster.backend import test_backend from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.stores import redis_running_models -from llm_server.llm.info import get_running_model def cluster_worker(): + counter = 0 while True: + test_prompt = False + if counter % 4 == 0: + # Only send a test prompt every 120 seconds. 
+ test_prompt = True threads = [] for n, v in cluster_config.all().items(): - thread = Thread(target=check_backend, args=(n, v)) + thread = Thread(target=check_backend, args=(n, v, test_prompt)) thread.start() threads.append(thread) for thread in threads: thread.join() + time.sleep(15) + counter += 1 -def check_backend(n, v): - # Check if backends are online - # TODO: also have test_backend() get the uptime - online = test_backend(v['backend_url'], v['mode']) +def check_backend(n, v, test_prompt): + online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt) + # purge_backend_from_running_models(n) if online: - running_model, err = get_running_model(v['backend_url'], v['mode']) - if not err: - cluster_config.set_backend_value(n, 'running_model', running_model) - purge_backend_from_running_models(n) - redis_running_models.sadd(running_model, n) + running_model = backend_info['model'] + for k, v in backend_info.items(): + cluster_config.set_backend_value(n, k, v) + redis_running_models.sadd(running_model, n) + else: + for model in redis_running_models.keys(): + redis_running_models.srem(model, n) + + # redis_running_models.srem(backend_info['model'], n) + # backend_cycler_store.lrem(backend_info['model'], 1, n) + cluster_config.set_backend_value(n, 'online', online) diff --git a/llm_server/config/config.py b/llm_server/config/config.py index b98ea49..645e81e 100644 --- a/llm_server/config/config.py +++ b/llm_server/config/config.py @@ -34,8 +34,9 @@ config_default_vars = { 'openai_moderation_enabled': True, 'netdata_root': None, 'show_backends': True, + 'cluster_workers': 30 } -config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] +config_required_vars = ['cluster', 'mode', 'llm_middleware_name'] mode_ui_names = { 'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 09fb127..9c2e7f3 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -26,7 +26,6 @@ def load_config(config_path): opts.log_prompts = config['log_prompts'] opts.concurrent_gens = config['concurrent_gens'] opts.frontend_api_client = config['frontend_api_client'] - opts.context_size = config['token_limit'] opts.show_num_prompts = config['show_num_prompts'] opts.show_uptime = config['show_uptime'] opts.cluster = config['cluster'] @@ -53,6 +52,7 @@ def load_config(config_path): opts.openai_silent_trim = config['openai_silent_trim'] opts.openai_moderation_enabled = config['openai_moderation_enabled'] opts.show_backends = config['show_backends'] + opts.cluster_workers = config['cluster_workers'] if opts.openai_expose_our_model and not opts.openai_api_key: print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.') diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py index b0db49e..d5d278f 100644 --- a/llm_server/custom_redis.py +++ b/llm_server/custom_redis.py @@ -9,17 +9,18 @@ from flask_caching import Cache from redis import Redis from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT -flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) +flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) ONE_MONTH_SECONDS = 2678000 -class RedisCustom: +class RedisCustom(Redis): """ A wrapper class to set 
prefixes to keys. """ def __init__(self, prefix, **kwargs): + super().__init__() self.redis = Redis(**kwargs) self.prefix = prefix try: @@ -108,6 +109,9 @@ class RedisCustom: ): return self.redis.hincrby(self._key(name), key, amount) + def zcard(self, name: KeyT): + return self.redis.zcard(self._key(name)) + def hdel(self, name: str, *keys: List): return self.redis.hdel(self._key(name), *keys) @@ -129,6 +133,9 @@ class RedisCustom: ): return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt) + def lpush(self, name: str, *values: FieldT): + return self.redis.lpush(self._key(name), *values) + def hset( self, name: str, @@ -164,6 +171,18 @@ class RedisCustom: def pipeline(self, transaction=True, shard_hint=None): return self.redis.pipeline(transaction, shard_hint) + def smembers(self, name: str): + return self.redis.smembers(self._key(name)) + + def spop(self, name: str, count: Optional[int] = None): + return self.redis.spop(self._key(name), count) + + def rpoplpush(self, src, dst): + return self.redis.rpoplpush(src, dst) + + def zpopmin(self, name: KeyT, count: Union[int, None] = None): + return self.redis.zpopmin(self._key(name), count) + def exists(self, *names: KeyT): n = [] for name in names: @@ -196,5 +215,13 @@ class RedisCustom: self.redis.delete(key) return flushed + def flushall(self, asynchronous: bool = ..., **kwargs) -> bool: + self.flush() + return True + + def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool: + self.flush() + return True + redis = RedisCustom('local_llm') diff --git a/llm_server/database/database.py b/llm_server/database/database.py index bf5f537..1dc2145 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -5,14 +5,14 @@ from threading import Thread import llm_server from llm_server import opts -from llm_server.custom_redis import redis +from llm_server.cluster.cluster_config import cluster_config from llm_server.database.conn import database from llm_server.llm.vllm import tokenize -def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False): +def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens: int = None, is_error: bool = False): def background_task(): - nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error + nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error # Try not to shove JSON into the database. 
if isinstance(response, dict) and response.get('results'): response = response['results'][0]['text'] @@ -23,10 +23,10 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe except: pass - prompt_tokens = llm_server.llm.get_token_count(prompt) + prompt_tokens = llm_server.llm.get_token_count(prompt, backend_url) if not is_error: if not response_tokens: - response_tokens = llm_server.llm.get_token_count(response) + response_tokens = llm_server.llm.get_token_count(response, backend_url) else: response_tokens = None @@ -47,7 +47,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe if token: increment_token_uses(token) - running_model = redis.get('running_model', str, 'ERROR') + running_model = cluster_config.get_backend(backend_url).get('model') timestamp = int(time.time()) cursor = database.cursor() try: @@ -56,7 +56,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, - (ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) + (ip, token, running_model, opts.mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) finally: cursor.close() diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index 6e39b42..3feb027 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -2,10 +2,10 @@ from llm_server.llm import oobabooga, vllm from llm_server.custom_redis import redis -def get_token_count(prompt: str): +def get_token_count(prompt: str, backend_url: str): backend_mode = redis.get('backend_mode', dtype=str) if backend_mode == 'vllm': - return vllm.tokenize(prompt) + return vllm.tokenize(prompt, backend_url) elif backend_mode == 'ooba': return oobabooga.tokenize(prompt) else: diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index 42c3bb7..f05b37c 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -1,14 +1,13 @@ from llm_server import opts -def generator(request_json_body, cluster_backend): +def generator(request_json_body, cluster_backend, timeout: int = None): if opts.mode == 'oobabooga': # from .oobabooga.generate import generate # return generate(request_json_body) raise NotImplementedError elif opts.mode == 'vllm': from .vllm.generate import generate - r = generate(request_json_body, cluster_backend) - return r + return generate(request_json_body, cluster_backend, timeout=timeout) else: raise Exception diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index 117da3f..d1218e2 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -20,3 +20,18 @@ def get_running_model(backend_url: str, mode: str): return False, e else: raise Exception + + +def get_info(backend_url: str, mode: str): + if mode == 'ooba': + return {} + # raise NotImplementedError + elif mode == 'vllm': + try: + r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout) + j = r.json() + except Exception as e: + return {} + return j + else: + raise Exception diff --git 
a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index e8268b1..e69f8fc 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -3,13 +3,17 @@ from typing import Tuple, Union import flask from llm_server import opts -from llm_server.llm import get_token_count +from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis +from llm_server.llm import get_token_count class LLMBackend: _default_params: dict + def __init__(self, backend_url: str): + self.backend_url = backend_url + def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers): raise NotImplementedError @@ -38,8 +42,9 @@ class LLMBackend: return True, None def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]: - prompt_len = get_token_count(prompt) - if prompt_len > opts.context_size - 10: + prompt_len = get_token_count(prompt, self.backend_url) + token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings'] + if prompt_len > token_limit - 10: model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str) - return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size' + return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size' return True, None diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py index 78f2190..fe450bf 100644 --- a/llm_server/llm/oobabooga/ooba_backend.py +++ b/llm_server/llm/oobabooga/ooba_backend.py @@ -1,9 +1,9 @@ from flask import jsonify +from llm_server.custom_redis import redis from ..llm_backend import LLMBackend from ...database.database import log_prompt from ...helpers import safe_list_get -from llm_server.custom_redis import redis from ...routes.helpers.client import format_sillytavern_err from ...routes.helpers.http import validate_json @@ -33,7 +33,7 @@ class OobaboogaBackend(LLMBackend): error_msg = 'Unknown error.' else: error_msg = error_msg.strip('.') + '.' - backend_response = format_sillytavern_err(error_msg, 'error') + backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url) log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True) return jsonify({ 'code': 500, @@ -50,7 +50,8 @@ class OobaboogaBackend(LLMBackend): backend_err = True backend_response = format_sillytavern_err( f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. 
Please check your parameters and try again.', - 'error') + error_type='error', + backend_url=self.backend_url) response_json_body['results'][0]['text'] = backend_response if not backend_err: @@ -61,7 +62,7 @@ class OobaboogaBackend(LLMBackend): **response_json_body }), 200 else: - backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error') + backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url) log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True) return jsonify({ 'code': 500, diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index caac445..72b0243 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -24,57 +24,6 @@ def prepare_json(json_data: dict): return json_data -def transform_to_text(json_request, api_response): - """ - This is to convert a streaming request to a non-streamed request. Don't think this is nessesary. - :param json_request: - :param api_response: - :return: - """ - prompt = transform_prompt_to_text(json_request['messages']) - text = '' - finish_reason = None - for line in api_response.split('\n'): - if line.startswith('data:'): - try: - data = json.loads(line[5:].strip()) - except json.decoder.JSONDecodeError: - break - if 'choices' in data: - for choice in data['choices']: - if 'delta' in choice and 'content' in choice['delta']: - text += choice['delta']['content'] - if data['choices'][0]['finish_reason']: - finish_reason = data['choices'][0]['finish_reason'] - - prompt_tokens = len(llm_server.llm.get_token_count(prompt)) - completion_tokens = len(llm_server.llm.get_token_count(text)) - running_model = redis.get('running_model', 'ERROR', dtype=str) - - # https://platform.openai.com/docs/api-reference/making-requests?lang=python - return { - "id": str(uuid4()), - "object": "chat.completion", - "created": int(time.time()), - "model": running_model, - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens - }, - "choices": [ - { - "message": { - "role": "assistant", - "content": text - }, - "finish_reason": finish_reason, - "index": 0 - } - ] - } - - def transform_prompt_to_text(prompt: list): text = '' for item in prompt: @@ -82,26 +31,26 @@ def transform_prompt_to_text(prompt: list): return text.strip('\n') -def handle_blocking_request(json_data: dict, cluster_backend): +def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10): try: - r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout) except requests.exceptions.ReadTimeout: - print(f'Failed to reach VLLM inference endpoint - request to backend timed out') + # print(f'Failed to reach VLLM inference endpoint - request to backend timed out') return False, None, 'Request to backend timed out' except Exception as e: - print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') + # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') return False, None, 'Request to backend encountered error' if r.status_code != 200: - print(f'Failed 
to reach VLLM inference endpoint - got code {r.status_code}') + # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}') return False, r, f'Backend returned {r.status_code}' return True, r, None -def generate(json_data: dict, cluster_backend): +def generate(json_data: dict, cluster_backend, timeout: int = None): if json_data.get('stream'): try: - return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout) except Exception as e: - print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') + return False else: - return handle_blocking_request(json_data, cluster_backend) + return handle_blocking_request(json_data, cluster_backend, timeout=timeout) diff --git a/llm_server/llm/vllm/info.py b/llm_server/llm/vllm/info.py index 996c614..0142301 100644 --- a/llm_server/llm/vllm/info.py +++ b/llm_server/llm/vllm/info.py @@ -1,3 +1,7 @@ +import requests + +from llm_server import opts + vllm_info = """

Important: This endpoint is running vllm and not all Oobabooga parameters are supported.

Supported Parameters: """ \ No newline at end of file +""" diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py index a698fd6..747a8b8 100644 --- a/llm_server/llm/vllm/tokenize.py +++ b/llm_server/llm/vllm/tokenize.py @@ -2,19 +2,21 @@ import requests import tiktoken from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config -def tokenize(prompt: str) -> int: +def tokenize(prompt: str, backend_url: str) -> int: if not prompt: # The tokenizers have issues when the prompt is None. return 0 tokenizer = tiktoken.get_encoding("cl100k_base") + token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings'] # First we tokenize it locally to determine if it's worth sending it to the backend. initial_estimate = len(tokenizer.encode(prompt)) - if initial_estimate <= opts.context_size + 200: + if initial_estimate <= token_limit + 200: try: - r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) j = r.json() return j['length'] except Exception as e: diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index 3db99d9..a28e59a 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -20,7 +20,7 @@ class VLLMBackend(LLMBackend): backend_response = '' log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, - response_tokens=response_json_body.get('details', {}).get('generated_tokens')) + response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url) return jsonify({'results': [{'text': backend_response}]}), 200 diff --git a/llm_server/opts.py b/llm_server/opts.py index 0d13979..bbd6201 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -5,7 +5,6 @@ concurrent_gens = 3 mode = 'oobabooga' backend_url = None -context_size = 5555 max_new_tokens = 500 auth_required = False log_prompts = False @@ -38,3 +37,4 @@ openai_silent_trim = False openai_moderation_enabled = True cluster = {} show_backends = True +cluster_workers = 30 diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py index 900210c..6e8c1ad 100644 --- a/llm_server/pre_fork.py +++ b/llm_server/pre_fork.py @@ -1,21 +1,9 @@ import sys -from redis import Redis - from llm_server.custom_redis import redis -from llm_server.routes.v1.generate_stats import generate_stats def server_startup(s): if not redis.get('daemon_started', dtype=bool): print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') sys.exit(1) - - # Flush the RedisPriorityQueue database. 
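The tokenize() change above drops the global opts.context_size in favour of the per-backend max_position_embeddings and keeps the two-stage counting strategy: a cheap local tiktoken estimate first, with a call to the backend's /tokenize endpoint only when the prompt is close enough to the limit for an exact count to matter. A condensed sketch of that flow; SSL verification and the configured timeouts are omitted, and falling back to the local estimate on request failure is an assumption based on the truncated except branch.

```python
import requests
import tiktoken


def count_tokens(prompt: str, backend_url: str, token_limit: int, timeout: int = 10) -> int:
    """Cheap local estimate first; only ask the backend tokenizer when it matters."""
    if not prompt:
        # The tokenizers misbehave on empty/None prompts.
        return 0
    tokenizer = tiktoken.get_encoding('cl100k_base')
    estimate = len(tokenizer.encode(prompt))
    if estimate > token_limit + 200:
        # So far over the limit that an exact count changes nothing.
        return estimate
    try:
        r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, timeout=timeout)
        return r.json()['length']
    except Exception:
        # Backend unreachable: the local estimate is better than failing.
        return estimate
```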
- queue_redis = Redis(host='localhost', port=6379, db=15) - for key in queue_redis.scan_iter('*'): - queue_redis.delete(key) - - # Cache the initial stats - print('Loading backend stats...') - generate_stats() diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py index a914362..040a129 100644 --- a/llm_server/routes/helpers/client.py +++ b/llm_server/routes/helpers/client.py @@ -2,13 +2,14 @@ from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis -def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'): - cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash'] +def format_sillytavern_err(msg: str, backend_url: str = 'none', error_type: str = 'info'): + cluster_backend_hash = cluster_config.get_backend(backend_url)['hash'] http_host = redis.get('http_host', dtype=str) return f"""``` === MESSAGE FROM LLM MIDDLEWARE AT {http_host} === --> {level.upper()} <- +-> {error_type.upper()} <- {msg} - -BACKEND HASH: {cluster_backend_hash} +``` +``` +BACKEND: {cluster_backend_hash} ```""" diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index 8e0036c..a272960 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler): msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' backend_response = self.handle_error(msg) if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True) + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True) return backend_response[0], 200 # We only return the response from handle_error(), not the error code def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: @@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler): # TODO: how to format this response_msg = error_msg else: - response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend) + response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url) return jsonify({ 'results': [{'text': response_msg}] diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 09ed06c..f058298 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -6,7 +6,7 @@ from uuid import uuid4 from redis import Redis from llm_server import opts -from llm_server.custom_redis import redis +from llm_server.custom_redis import RedisCustom, redis def increment_ip_count(client_ip: str, redis_key): @@ -20,12 +20,12 @@ def decrement_ip_count(client_ip: str, redis_key): class RedisPriorityQueue: - def __init__(self): - self.redis = Redis(host='localhost', port=6379, db=15) + def __init__(self, name: str = 'priority_queue', db: int = 12): + self.redis = RedisCustom(name, db=db) self.pubsub = self.redis.pubsub() self.pubsub.subscribe('events') - def put(self, item, priority): + def put(self, item, priority, selected_model): event = DataEvent() # Check if the IP is already in the 
dictionary and if it has reached the limit @@ -36,7 +36,7 @@ class RedisPriorityQueue: print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') return None # reject the request - self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority}) + self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority}) self.increment_ip_count(item[1], 'queued_ip_count') return event @@ -61,12 +61,23 @@ class RedisPriorityQueue: def __len__(self): return self.redis.zcard('queue') + def len(self, model_name): + count = 0 + for key in self.redis.zrange('queue', 0, -1): + item = json.loads(key) + if item[2] == model_name: + count += 1 + return count + def get_queued_ip_count(self, client_ip: str): q = self.redis.hget('queued_ip_count', client_ip) if not q: return 0 return 0 + def flush(self): + self.redis.flush() + class DataEvent: def __init__(self, event_id=None): @@ -87,12 +98,16 @@ class DataEvent: priority_queue = RedisPriorityQueue() -def incr_active_workers(): - redis.incr('active_gen_workers') +def incr_active_workers(selected_model: str, backend_url: str): + redis.incr(f'active_gen_workers:{selected_model}') + redis.incr(f'active_gen_workers:{backend_url}') -def decr_active_workers(): - redis.decr('active_gen_workers') - new_count = redis.get('active_gen_workers', 0, dtype=int) - if new_count < 0: - redis.set('active_gen_workers', 0) +def decr_active_workers(selected_model: str, backend_url: str): + redis.decr(f'active_gen_workers:{selected_model}') + if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0: + redis.set(f'active_gen_workers:{selected_model}', 0) + + redis.decr(f'active_gen_workers:{backend_url}') + if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0: + redis.set(f'active_gen_workers:{backend_url}', 0) diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index ecae085..83f510a 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.auth import parse_token from llm_server.routes.helpers.http import require_api_key, validate_json -from llm_server.routes.queue import priority_queue +from llm_server.routes.queue import RedisPriorityQueue, priority_queue DEFAULT_PRIORITY = 9999 class RequestHandler: - def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None): + def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None): self.request = incoming_request self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' @@ -37,11 +37,12 @@ class RequestHandler: self.client_ip = self.get_client_ip() self.token = self.get_auth_token() self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit() - self.cluster_backend = get_a_cluster_backend() - self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend) - self.backend = get_backend_handler(self.cluster_backend) + self.backend_url = get_a_cluster_backend(selected_model) + self.cluster_backend_info = cluster_config.get_backend(self.backend_url) + self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url) self.parameters = None self.used = False + self.selected_model = selected_model redis.zadd('recent_prompters', {self.client_ip: time.time()}) def 
get_auth_token(self): @@ -123,7 +124,7 @@ class RequestHandler: backend_response = self.handle_error(combined_error_message, 'Validation Error') if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True) + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True) return False, backend_response return True, (None, 0) @@ -135,14 +136,16 @@ class RequestHandler: request_valid, invalid_response = self.validate_request(prompt, do_log=True) if not request_valid: return (False, None, None, 0), invalid_response - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority) + event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.backend_url), self.token_priority, self.selected_model) else: event = None if not event: return (False, None, None, 0), self.handle_ratelimited() + # TODO: add wait timeout success, response, error_msg = event.wait() + end_time = time.time() elapsed_time = end_time - self.start_time @@ -164,7 +167,7 @@ class RequestHandler: else: error_msg = error_msg.strip('.') + '.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True) + log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -184,7 +187,7 @@ class RequestHandler: if return_json_err: error_msg = 'The backend did not return valid JSON.' 
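The priority_queue.put() call above now carries selected_model, which RedisPriorityQueue (changed earlier in this patch) serialises into the queue entry itself: each entry is a JSON-encoded (item, event_id, selected_model) member of a sorted set scored by the negated priority, and the new len(model_name) walks the set counting members whose third field matches. A minimal sketch of that layout, reusing the 'queue' key name from the patch:

```python
import json

from redis import Redis

r = Redis()


def put(item, priority: int, event_id: str, selected_model: str):
    # Negate the priority so zpopmin() pops the highest-priority entry first.
    r.zadd('queue', {json.dumps((item, event_id, selected_model)): -priority})


def pop():
    popped = r.zpopmin('queue')
    if not popped:
        return None
    member, _score = popped[0]
    return json.loads(member)  # -> [item, event_id, selected_model]


def queued_for_model(model_name: str) -> int:
    # Mirrors the new len(model_name): scan the set and count matching entries.
    return sum(1 for m in r.zrange('queue', 0, -1) if json.loads(m)[2] == model_name)
```

The same model/backend pair is what the reworked incr_active_workers() and decr_active_workers() key their per-model and per-backend counters on.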
backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True) + log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -218,11 +221,11 @@ class RequestHandler: raise NotImplementedError -def get_backend_handler(mode): +def get_backend_handler(mode, backend_url: str): if mode == 'oobabooga': - return OobaboogaBackend() + return OobaboogaBackend(backend_url) elif mode == 'vllm': - return VLLMBackend() + return VLLMBackend(backend_url) else: raise Exception diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index b4dea54..9e1f291 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -1,6 +1,7 @@ from datetime import datetime from llm_server.custom_redis import redis +from llm_server.helpers import round_up_base server_start_time = datetime.now() @@ -14,10 +15,32 @@ def get_total_proompts(): return count -def get_active_gen_workers(): - active_gen_workers = redis.get('active_gen_workers') +def get_active_gen_workers(selected_model: str = None, ): + active_gen_workers = redis.get(f'active_gen_workers:{selected_model}') if active_gen_workers is None: count = 0 else: count = int(active_gen_workers) return count + + +def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers): + if active_gen_workers < concurrent_gens: + return 0 + elif active_gen_workers >= concurrent_gens: + # Calculate how long it will take to complete the currently running gens and the queued requests. + # If the proompters in the queue are equal to the number of workers, just use the calculated generation time. + # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round + # that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally, + # Add gen_time_calc to the time to account for the currently running generations. + # This assumes that all active workers will finish at the same time, which is unlikely. + # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times. + proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \ + else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc) + return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0 + elif proompters_in_queue == 0 and active_gen_workers == 0: + # No queue, no workers + return 0 + else: + # No queue + return gen_time_calc diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index 715288f..39db078 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -3,18 +3,20 @@ import traceback from flask import jsonify, request from . 
import bp -from ..helpers.client import format_sillytavern_err from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler +from ...cluster.backend import get_a_cluster_backend +from ...cluster.cluster_config import cluster_config -@bp.route('/generate', methods=['POST']) -def generate(): +@bp.route('/v1/generate', methods=['POST']) +@bp.route('//v1/generate', methods=['POST']) +def generate(model_name=None): request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: - handler = OobaRequestHandler(request) + handler = OobaRequestHandler(request, model_name) try: return handler.handle_request() except Exception: diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index 66dd316..30e0967 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -2,74 +2,32 @@ import time from datetime import datetime from llm_server import opts -from llm_server.cluster.backend import get_a_cluster_backend, test_backend +from llm_server.cluster.backend import get_a_cluster_backend from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.model_choices import get_model_choices from llm_server.custom_redis import redis from llm_server.database.database import get_distinct_ips_24h, sum_column -from llm_server.helpers import deep_sort, round_up_base -from llm_server.llm.info import get_running_model -from llm_server.routes.queue import priority_queue -from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time - - -def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers): - if active_gen_workers < concurrent_gens: - return 0 - elif active_gen_workers >= concurrent_gens: - # Calculate how long it will take to complete the currently running gens and the queued requests. - # If the proompters in the queue are equal to the number of workers, just use the calculated generation time. - # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round - # that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally, - # Add gen_time_calc to the time to account for the currently running generations. - # This assumes that all active workers will finish at the same time, which is unlikely. - # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times. 
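calculate_wait_time(), which this hunk removes from generate_stats.py after moving it into routes/stats.py, is easier to follow with numbers plugged in. The sketch below condenses the reachable branches; round_up_base() is not shown in this patch, so rounding up to the next multiple of the base is an assumption.

```python
import math


def round_up_base(n: float, base: int) -> int:
    # Assumed behaviour: round n up to the next multiple of base.
    return int(math.ceil(n / base) * base)


def estimated_wait(gen_time_calc, queued, concurrent_gens, active_workers):
    if active_workers < concurrent_gens:
        # A worker slot is free, so a new request starts immediately.
        return 0
    if queued / concurrent_gens <= 1:
        queue_wait = gen_time_calc
    else:
        queue_wait = round_up_base((queued / concurrent_gens) * gen_time_calc, base=gen_time_calc)
    # Add one more generation's worth of time for the batch currently running.
    return queue_wait + gen_time_calc


# 8 s average generation, 3 concurrent slots, 5 queued, 3 already running:
# 5/3 * 8 = 13.3 s, rounded up to 16 s, plus 8 s for the running batch = 24 s.
print(estimated_wait(8, 5, 3, 3))  # 24
```

The last two branches of the original function are unreachable, since the first two conditions already cover every value of active_gen_workers.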
- proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \ - else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc) - return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0 - elif proompters_in_queue == 0 and active_gen_workers == 0: - # No queue, no workers - return 0 - else: - # No queue - return gen_time_calc +from llm_server.helpers import deep_sort +from llm_server.routes.stats import get_total_proompts, server_start_time def generate_stats(regen: bool = False): if not regen: - c = redis.get('proxy_stats', dtype=dict) + c = redis.getp('proxy_stats') if c: return c default_backend_url = get_a_cluster_backend() default_backend_info = cluster_config.get_backend(default_backend_url) if not default_backend_info.get('mode'): - # TODO: remove - print('DAEMON NOT FINISHED STARTING') return base_client_api = redis.get('base_client_api', dtype=str) proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf')) - average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0) - - online = test_backend(default_backend_url, default_backend_info['mode']) - if online: - running_model, err = get_running_model(default_backend_url, default_backend_info['mode']) - cluster_config.set_backend_value(default_backend_url, 'running_model', running_model) - else: - running_model = None - - active_gen_workers = get_active_gen_workers() - proompters_in_queue = len(priority_queue) - - # This is so wildly inaccurate it's disabled. - # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0) - - # TODO: make this for the currently selected backend - estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers) output = { 'default': { - 'model': running_model, - 'backend': default_backend_info['hash'], + 'model': default_backend_info['model'], + 'backend': default_backend_url, }, 'stats': { 'proompters': { @@ -78,21 +36,14 @@ def generate_stats(regen: bool = False): }, 'proompts_total': get_total_proompts() if opts.show_num_prompts else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, - 'average_generation_elapsed_sec': int(average_generation_elapsed_sec), # 'estimated_avg_tps': estimated_avg_tps, 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, 'num_backends': len(cluster_config.all()) if opts.show_backends else None, }, - 'online': online, 'endpoints': { 'blocking': f'https://{base_client_api}', 'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, }, - 'queue': { - 'processing': active_gen_workers, - 'queued': proompters_in_queue, - 'estimated_wait_sec': int(estimated_wait_sec), - }, 'timestamp': int(time.time()), 'config': { 'gatekeeper': 'none' if opts.auth_required is False else 'token', @@ -106,42 +57,30 @@ def generate_stats(regen: bool = False): 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None, } + # TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself + if opts.show_backends: for backend_url, v in cluster_config.all().items(): backend_info = cluster_config.get_backend(backend_url) if not backend_info['online']: continue - - # TODO: have this fetch the data from VLLM which will display GPU utalization - # if opts.netdata_root: - # 
netdata_stats = {} - # power_states = get_power_states() - # for gpu, power_state in power_states.items(): - # netdata_stats[gpu] = { - # 'power_state': power_state, - # # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu'))) - # } - # else: - # netdata_stats = {} - netdata_stats = {} - - # TODO: use value returned by VLLM backend here - # backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None - backend_uptime = -1 - + backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None output['backend_info'][backend_info['hash']] = { 'uptime': backend_uptime, - # 'context_size': opts.context_size, - 'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'), + 'max_tokens': backend_info['model_config']['max_position_embeddings'], + 'model': backend_info['model'], 'mode': backend_info['mode'], - 'nvidia': netdata_stats + 'nvidia': backend_info['nvidia'], } else: output['backend_info'] = {} + output['default'] = get_model_choices(regen=True)[1] + result = deep_sort(output) # It may take a bit to get the base client API, so don't cache until then. if base_client_api: - redis.set_dict('proxy_stats', result) # Cache with no expiry + redis.setp('proxy_stats', result) + return result diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py index 90778e5..355b415 100644 --- a/llm_server/routes/v1/info.py +++ b/llm_server/routes/v1/info.py @@ -10,13 +10,9 @@ from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends from ...cluster.cluster_config import cluster_config -@bp.route('/model', methods=['GET']) -@bp.route('//model', methods=['GET']) +@bp.route('/v1/model', methods=['GET']) +@bp.route('//v1/model', methods=['GET']) def get_model(model_name=None): - if not model_name: - b = get_a_cluster_backend() - model_name = cluster_config.get_backend(b)['running_model'] - # We will manage caching ourself since we don't want to cache # when the backend is down. Also, Cloudflare won't cache 500 errors. 
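The comment above is the reason /v1/model manages its own cache entry instead of relying on a cache decorator: a response produced while the backend is down must never be stored, locally or by Cloudflare. A small sketch of that conditional-caching pattern with flask_caching; the route, helper, and timeout here are illustrative assumptions rather than the project's actual values.

```python
from flask import Flask, jsonify, request
from flask_caching import Cache

app = Flask(__name__)
flask_cache = Cache(app, config={'CACHE_TYPE': 'SimpleCache'})


def lookup_model(model_name):
    # Hypothetical stand-in for querying the cluster config / backend.
    return {'result': model_name, 'model_backend_count': 1}


@app.route('/<model_name>/v1/model')
def get_model(model_name=None):
    cache_key = 'model_cache::' + request.url
    cached = flask_cache.get(cache_key)
    if cached is not None:
        return jsonify(cached), 200

    info = lookup_model(model_name)
    if info is None:
        # Backend down or unknown model: return the error but do NOT cache it.
        return jsonify({'code': 400, 'msg': 'Model not found'}), 400

    flask_cache.set(cache_key, info, timeout=60)  # only successful lookups are cached
    return jsonify(info), 200
```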
cache_key = 'model_cache::' + request.url @@ -25,6 +21,9 @@ def get_model(model_name=None): if cached_response: return cached_response + if not model_name: + model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model') + if not is_valid_model(model_name): response = jsonify({ 'code': 400, @@ -32,7 +31,6 @@ def get_model(model_name=None): }), 400 else: num_backends = len(get_backends_from_model(model_name)) - response = jsonify({ 'result': opts.manual_model_name if opts.manual_model_name else model_name, 'model_backend_count': num_backends, @@ -47,7 +45,8 @@ def get_model(model_name=None): @requires_auth def get_backend(): online, offline = get_backends() - result = [] + result = {} for i in online + offline: - result.append(cluster_config.get_backend(i)) + info = cluster_config.get_backend(i) + result[info['hash']] = info return jsonify(result), 200 diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index 626e34b..e92052e 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -1,7 +1,7 @@ import threading import time -from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis from llm_server.llm.generator import generator from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue @@ -9,12 +9,16 @@ from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip def worker(): while True: - need_to_wait() - (request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get() - need_to_wait() + (request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get() + if not selected_model: + selected_model = cluster_config.get_backend(backend_url)['model'] + + # This wait time is "invisible", meaning the worker may as + # well be still waiting to get an item from the queue. + need_to_wait(backend_url) increment_ip_count(client_ip, 'processing_ips') - incr_active_workers() + incr_active_workers(selected_model, backend_url) if not request_json_body: # This was a dummy request from the websocket handler. @@ -22,12 +26,12 @@ def worker(): continue try: - success, response, error_msg = generator(request_json_body, cluster_backend) + success, response, error_msg = generator(request_json_body, backend_url) event = DataEvent(event_id) event.set((success, response, error_msg)) finally: decrement_ip_count(client_ip, 'processing_ips') - decr_active_workers() + decr_active_workers(selected_model, backend_url) def start_workers(num_workers: int): @@ -40,11 +44,12 @@ def start_workers(num_workers: int): print(f'Started {i} inference workers.') -def need_to_wait(): +def need_to_wait(backend_url: str): # We need to check the number of active workers since the streaming endpoint may be doing something. 
- active_workers = redis.get('active_gen_workers', 0, dtype=int) + active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) + concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1) s = time.time() - while active_workers >= opts.concurrent_gens: + while active_workers >= concurrent_gens: time.sleep(0.01) e = time.time() if e - s > 0.5: diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py index 447046f..ca82d60 100644 --- a/llm_server/workers/mainer.py +++ b/llm_server/workers/mainer.py @@ -5,7 +5,7 @@ from llm_server.cluster.backend import get_a_cluster_backend, get_backends from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis from llm_server.database.database import weighted_average_column_for_model -from llm_server.llm.info import get_running_model +from llm_server.llm.info import get_info, get_running_model def main_background_thread(): @@ -14,8 +14,9 @@ def main_background_thread(): for backend_url in online: backend_info = cluster_config.get_backend(backend_url) backend_mode = backend_info['mode'] - running_model, err = get_running_model(backend_url, backend_mode) - if err: + backend_info = get_info(backend_url, backend_mode) + running_model = backend_info.get('model') + if not running_model: continue average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode) @@ -25,21 +26,6 @@ def main_background_thread(): cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens) if average_generation_elapsed_sec and average_output_tokens: cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps) - - default_backend_url = get_a_cluster_backend() - default_backend_info = cluster_config.get_backend(default_backend_url) - default_backend_mode = default_backend_info['mode'] - default_running_model, err = get_running_model(default_backend_url, default_backend_mode) - if err: - continue - - default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_running_model, default_running_model, default_backend_mode) - if default_average_generation_elapsed_sec: - redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec) - if default_average_output_tokens: - redis.set('average_output_tokens', default_average_output_tokens) - if default_average_generation_elapsed_sec and default_average_output_tokens: - redis.set('estimated_avg_tps', default_estimated_avg_tps) time.sleep(30) diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py index 6a33835..ed6ff65 100644 --- a/llm_server/workers/printer.py +++ b/llm_server/workers/printer.py @@ -1,6 +1,7 @@ import logging import time +from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis from llm_server.routes.queue import priority_queue @@ -17,9 +18,11 @@ if not logger.handlers: def console_printer(): time.sleep(3) while True: - processing = redis.hkeys('processing_ips') + processing = redis.keys('active_gen_workers:http*') # backends always start with http processing_count = 0 - for ip in processing: - processing_count += int(redis.hget('processing_ips', ip)) - logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}') + if len(processing): + for k in processing: + processing_count += redis.get(k, default=0, dtype=int) + 
backends = [k for k, v in cluster_config.all().items() if v['online']] + logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}') time.sleep(10) diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py index 83bac2d..0c82559 100644 --- a/llm_server/workers/threader.py +++ b/llm_server/workers/threader.py @@ -15,11 +15,11 @@ from llm_server.workers.recenter import recent_prompters_thread def cache_stats(): while True: generate_stats(regen=True) - time.sleep(1) + time.sleep(5) def start_background(): - start_workers(opts.concurrent_gens) + start_workers(opts.cluster_workers) t = Thread(target=main_background_thread) t.daemon = True diff --git a/other/vllm/vllm_api_server.py b/other/vllm/vllm_api_server.py old mode 100755 new mode 100644 diff --git a/requirements.txt b/requirements.txt index 7b49eed..6057884 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,20 +1,17 @@ flask~=2.3.3 -flask_cors pyyaml~=6.0.1 flask_caching requests~=2.31.0 tiktoken~=0.5.0 -gunicorn gevent~=23.9.0.post1 -async-timeout -flask-sock -uvicorn~=0.23.2 -fastapi~=0.103.1 torch~=2.0.1 PyMySQL~=1.1.0 -DBUtils~=3.0.3 simplejson~=3.19.1 websockets~=11.0.3 basicauth~=1.0.0 openai~=0.28.0 urllib3~=2.0.4 +flask-sock==0.6.0 +gunicorn==21.2.0 +redis==5.0.1 +git+https://github.com/vllm-project/vllm \ No newline at end of file diff --git a/server.py b/server.py index 0214b49..699290f 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,3 @@ -from llm_server.cluster.cluster_config import cluster_config - try: import gevent.monkey @@ -14,10 +12,10 @@ from pathlib import Path import simplejson as json from flask import Flask, jsonify, render_template, request -from llm_server.cluster.backend import get_a_cluster_backend, get_backends -from llm_server.cluster.redis_cycle import load_backend_cycle +from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.model_choices import get_model_choices from llm_server.config.config import mode_ui_names -from llm_server.config.load import load_config, parse_backends +from llm_server.config.load import load_config from llm_server.database.conn import database from llm_server.database.create import create_db from llm_server.pre_fork import server_startup @@ -26,10 +24,7 @@ from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp from llm_server.sock import init_socketio -# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. -# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail -# TODO: implement background thread to test backends via sending test prompts -# TODO: if backend fails request, mark it as down +# TODO: per-backend workers # TODO: allow setting concurrent gens per-backend # TODO: set the max tokens to that of the lowest backend # TODO: implement RRD backend loadbalancer option @@ -42,6 +37,7 @@ from llm_server.sock import init_socketio # TODO: have VLLM report context size, uptime # Lower priority +# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. 
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens # TODO: the estiamted wait time lags behind the stats # TODO: simulate OpenAI error messages regardless of endpoint @@ -69,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info from llm_server.custom_redis import flask_cache from llm_server.llm import redis -from llm_server.routes.stats import get_active_gen_workers from llm_server.routes.v1.generate_stats import generate_stats app = Flask(__name__) init_socketio(app) -app.register_blueprint(bp, url_prefix='/api/v1/') +app.register_blueprint(bp, url_prefix='/api/v2/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') flask_cache.init_app(app) flask_cache.clear() @@ -94,37 +89,23 @@ if not success: database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) create_db() -cluster_config.clear() -cluster_config.load(parse_backends(config)) -on, off = get_backends() -load_backend_cycle('backend_cycler', on + off) - @app.route('/') @app.route('/api') @app.route('/api/openai') @flask_cache.cached(timeout=10) def home(): - # Use the default backend - backend_url = get_a_cluster_backend() - if backend_url: - backend_info = cluster_config.get_backend(backend_url) - stats = generate_stats(backend_url) - else: - backend_info = stats = None + base_client_api = redis.get('base_client_api', dtype=str) + stats = generate_stats() - if not stats['online']: - running_model = estimated_wait_sec = 'offline' - else: - running_model = backend_info['running_model'] + model_choices, default_backend_info = get_model_choices() - active_gen_workers = get_active_gen_workers() - if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens: - # There will be a wait if the queue is empty but prompts are processing, but we don't - # know how long. - estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds" - else: - estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds" + if default_backend_info['queued'] == 0 and default_backend_info['queued'] >= opts.concurrent_gens: + # There will be a wait if the queue is empty but prompts are processing, but we don't + # know how long. 
+ default_estimated_wait_sec = f"less than {default_backend_info['estimated_wait']} seconds" + else: + default_estimated_wait_sec = f"{default_backend_info['estimated_wait']} seconds" if len(config['analytics_tracking_code']): analytics_tracking_code = f"" @@ -137,39 +118,35 @@ def home(): info_html = '' mode_info = '' - using_vllm = False for k, v in cluster_config.all().items(): - if v['mode'] == vllm: - using_vllm = True + if v['mode'] == 'vllm': + mode_info = vllm_info break - if using_vllm == 'vllm': - mode_info = vllm_info - - base_client_api = redis.get('base_client_api', dtype=str) - return render_template('home.html', llm_middleware_name=opts.llm_middleware_name, analytics_tracking_code=analytics_tracking_code, info_html=info_html, - current_model=opts.manual_model_name if opts.manual_model_name else running_model, - client_api=f'https://{base_client_api}', - ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, - estimated_wait=estimated_wait_sec, + default_model=default_backend_info['model'], + default_active_gen_workers=default_backend_info['processing'], + default_proompters_in_queue=default_backend_info['queued'], + current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model, + client_api=f'https://{base_client_api}/v2', + ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled', + default_estimated_wait=default_estimated_wait_sec, mode_name=mode_ui_names[opts.mode][0], api_input_textbox=mode_ui_names[opts.mode][1], streaming_input_textbox=mode_ui_names[opts.mode][2], - context_size=opts.context_size, + default_context_size=default_backend_info['context_size'], stats_json=json.dumps(stats, indent=4, ensure_ascii=False), extra_info=mode_info, openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled', expose_openai_system_prompt=opts.expose_openai_system_prompt, enable_streaming=opts.enable_streaming, + model_choices=model_choices ) -# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend - @app.route('/') @app.route('//') def fallback(first=None, rest=None): diff --git a/templates/home.html b/templates/home.html index 4b9c153..fb6f3e9 100644 --- a/templates/home.html +++ b/templates/home.html @@ -65,6 +65,10 @@ .hidden { display: none; } + + .header-workers { + font-weight: normal; + } @@ -76,8 +80,12 @@

{{ llm_middleware_name }}

-

Current Model: {{ current_model }}

-

Estimated Wait Time: {{ estimated_wait }}

+

Current Model: {{ default_model }}

+

+ Estimated Wait Time: {{ default_estimated_wait }}
+ Processing: {{ default_active_gen_workers }}
+ Queued: {{ default_proompters_in_queue }} +


Client API URL: {{ client_api }}

Streaming API URL: {{ ws_client_api if enable_streaming else 'Disabled' }}

@@ -101,7 +109,7 @@ API key textbox.
  • Click Connect to test the connection.
  • -
  • Open your preset config and set Context Size to {{ context_size }}.
  • +
  • Open your preset config and set Context Size to {{ default_context_size }}.
  • Follow this guide to get set up: rentry.org/freellamas
  • @@ -119,9 +127,30 @@
    + {% for key, value in model_choices.items() %} +
    +

    {{ key }} - {{ value.backend_count }} workers

    +

    + Estimated Wait Time: {{ value.estimated_wait }}
    + Processing: {{ value.processing }}
    + Queued: {{ value.queued }}
    +

    +

    + Client API URL: {{ value.client_api }}
    + Streaming API URL: {{ value.ws_client_api }}
    + OpenAI-Compatible API URL: {{ value.openai_client_api }} +

    +

    Context Size: {{ value.context_size }}

    +

    Average Generation Time: {{ value.avg_generation_time | int }} seconds

    +
    +
    + {% endfor %} + +
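The per-model block appended to home.html above reads a handful of fields off each model_choices entry. A minimal sketch of the data shape the template expects, rendered standalone with Jinja2; the field names come from the template, while the model name and values are made up for illustration.

```python
from jinja2 import Template

block = Template("""
{% for key, value in model_choices.items() %}
{{ key }} - {{ value.backend_count }} workers
  Estimated Wait Time: {{ value.estimated_wait }}
  Processing: {{ value.processing }} | Queued: {{ value.queued }}
  Client API URL: {{ value.client_api }}
  Streaming API URL: {{ value.ws_client_api }}
  OpenAI-Compatible API URL: {{ value.openai_client_api }}
  Context Size: {{ value.context_size }}
  Average Generation Time: {{ value.avg_generation_time | int }} seconds
{% endfor %}
""")

model_choices = {
    'example-13b': {  # made-up model name
        'backend_count': 2,
        'estimated_wait': '10 seconds',
        'processing': 1,
        'queued': 0,
        'client_api': 'https://example.com/api/v2/example-13b',
        'ws_client_api': 'wss://example.com/api/v2/example-13b/v1/stream',
        'openai_client_api': 'https://example.com/api/openai/v1',
        'context_size': 4096,
        'avg_generation_time': 8.4,
    },
}

print(block.render(model_choices=model_choices))
```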