From e7b57cad7b6a5c4a93b2066de2058553da349bb6 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Thu, 28 Sep 2023 18:40:24 -0600
Subject: [PATCH] set up cluster config and basic background workers

---
 daemon.py                                     |  2 +-
 llm_server/cluster/__init__.py                |  0
 llm_server/cluster/datastore.py               |  0
 llm_server/cluster/funcs/__init__.py          |  0
 llm_server/cluster/funcs/backend.py           | 26 +++++++++++
 llm_server/cluster/redis_config_cache.py      | 42 +++++++++++++++++
 llm_server/cluster/worker.py                  | 25 +++++++++++
 llm_server/config/load.py                     | 23 ++++++----
 llm_server/config/redis_config.py             |  3 ++
 .../{routes/cache.py => custom_redis.py}      | 45 ++++++++++++++++---
 llm_server/database/database.py               |  2 +-
 llm_server/helpers.py                         |  2 +-
 llm_server/llm/__init__.py                    |  2 +-
 llm_server/llm/info.py                        | 11 ++---
 llm_server/llm/llm_backend.py                 |  2 +-
 llm_server/llm/oobabooga/ooba_backend.py      |  2 +-
 llm_server/llm/openai/transform.py            |  2 +-
 llm_server/llm/vllm/generate.py               |  2 +-
 llm_server/opts.py                            |  1 +
 llm_server/pre_fork.py                        |  2 +-
 llm_server/routes/helpers/client.py           |  3 +-
 llm_server/routes/openai/chat_completions.py  |  2 +-
 llm_server/routes/openai/completions.py       |  5 +--
 llm_server/routes/openai/info.py              |  2 +-
 llm_server/routes/openai/models.py            |  2 +-
 llm_server/routes/openai/simulated.py         |  2 +-
 llm_server/routes/queue.py                    |  2 +-
 llm_server/routes/request_handler.py          |  2 +-
 llm_server/routes/stats.py                    |  2 +-
 llm_server/routes/v1/generate_stats.py        |  2 +-
 llm_server/routes/v1/generate_stream.py       |  1 -
 llm_server/routes/v1/info.py                  |  2 +-
 llm_server/routes/v1/proxy.py                 |  4 +-
 llm_server/workers/blocking.py                |  2 +-
 llm_server/workers/main.py                    |  3 +-
 llm_server/workers/printer.py                 |  2 +-
 llm_server/workers/recent.py                  |  2 +-
 requirements.txt                              |  1 -
 server.py                                     |  9 ++--
 test-cluster.py                               | 29 ++++++++++++
 40 files changed, 219 insertions(+), 54 deletions(-)
 create mode 100644 llm_server/cluster/__init__.py
 create mode 100644 llm_server/cluster/datastore.py
 create mode 100644 llm_server/cluster/funcs/__init__.py
 create mode 100644 llm_server/cluster/funcs/backend.py
 create mode 100644 llm_server/cluster/redis_config_cache.py
 create mode 100644 llm_server/cluster/worker.py
 create mode 100644 llm_server/config/redis_config.py
 rename llm_server/{routes/cache.py => custom_redis.py} (79%)
 create mode 100644 test-cluster.py

diff --git a/daemon.py b/daemon.py
index 20ec300..93e8d34 100644
--- a/daemon.py
+++ b/daemon.py
@@ -1,6 +1,6 @@
 import time
 
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 try:
     import gevent.monkey
diff --git a/llm_server/cluster/__init__.py b/llm_server/cluster/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_server/cluster/datastore.py b/llm_server/cluster/datastore.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_server/cluster/funcs/__init__.py b/llm_server/cluster/funcs/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_server/cluster/funcs/backend.py b/llm_server/cluster/funcs/backend.py
new file mode 100644
index 0000000..5b7b535
--- /dev/null
+++ b/llm_server/cluster/funcs/backend.py
@@ -0,0 +1,26 @@
+from llm_server.cluster.redis_config_cache import RedisClusterStore
+from llm_server.llm.info import get_running_model
+
+
+def test_backend(backend_url: str):
+    running_model, err = get_running_model(backend_url)
+    if not running_model:
+        return False
+    return True
+
+
+def get_best_backends():
+    cluster_config = RedisClusterStore('cluster_config')
+    backends = cluster_config.all()
+    result = {}
+    for k, v in backends.items():
+        b = cluster_config.get_backend(k)
+        status = b['online']
+        priority = b['priority']
+        result[k] = {'status': status, 'priority': priority}
+    online_backends = sorted(
+        ((url, info) for url, info in backends.items() if info['online']),
+        key=lambda kv: kv[1]['priority'],
+        reverse=True
+    )
+    return [url for url, info in online_backends]
diff --git a/llm_server/cluster/redis_config_cache.py b/llm_server/cluster/redis_config_cache.py
new file mode 100644
index 0000000..00a6a02
--- /dev/null
+++ b/llm_server/cluster/redis_config_cache.py
@@ -0,0 +1,42 @@
+import pickle
+
+from llm_server.custom_redis import RedisCustom
+
+
+class RedisClusterStore:
+    def __init__(self, name: str, **kwargs):
+        self.name = name
+        self.config_redis = RedisCustom(name, **kwargs)
+
+    def clear(self):
+        self.config_redis.flush()
+
+    def load(self, config: dict):
+        for k, v in config.items():
+            self.set_backend(k, v)
+
+    def set_backend(self, name: str, values: dict):
+        self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
+        self.set_backend_value(name, 'online', False)
+
+    def set_backend_value(self, key: str, name: str, value):
+        self.config_redis.hset(key, name, pickle.dumps(value))
+
+    def get_backend(self, name: str):
+        r = self.config_redis.hgetall(name)
+        output = {}
+        for k, v in r.items():
+            output[k.decode('utf8')] = pickle.loads(v)
+        return output
+
+    def all(self):
+        keys = self.config_redis.keys('*')
+        if keys:
+            result = {}
+            for key in keys:
+                if key != f'{self.name}:____':
+                    v = self.get_backend(key)
+                    result[key] = v
+            return result
+        else:
+            return {}
diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py
new file mode 100644
index 0000000..4aaaf6a
--- /dev/null
+++ b/llm_server/cluster/worker.py
@@ -0,0 +1,25 @@
+import time
+from threading import Thread
+
+from llm_server.cluster.funcs.backend import test_backend
+from llm_server.cluster.redis_config_cache import RedisClusterStore
+
+cluster_config = RedisClusterStore('cluster_config')
+
+
+def cluster_worker():
+    while True:
+        threads = []
+        for n, v in cluster_config.all().items():
+            thread = Thread(target=check_backend, args=(n, v))
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+        time.sleep(10)
+
+
+def check_backend(n, v):
+    # Check if backends are online
+    online = test_backend(v['backend_url'])
+    cluster_config.set_backend_value(n, 'online', online)
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
index 64469b2..82afe81 100644
--- a/llm_server/config/load.py
+++ b/llm_server/config/load.py
@@ -5,22 +5,17 @@ import openai
 
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.custom_redis import redis
 from llm_server.database.conn import database
 from llm_server.database.database import get_number_of_rows
-from llm_server.helpers import resolve_path
-from llm_server.routes.cache import redis
 
 
-def load_config(config_path, script_path):
+def load_config(config_path):
     config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
     success, config, msg = config_loader.load_config()
     if not success:
         return success, config, msg
 
-    # Resolve relative directory to the directory of the script
-    if config['database_path'].startswith('./'):
-        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
     if config['mode'] not in ['oobabooga', 'vllm']:
         print('Unknown mode:', config['mode'])
         sys.exit(1)
@@ -34,7 +29,7 @@ def load_config(config_path, script_path):
     opts.context_size = config['token_limit']
     opts.show_num_prompts = config['show_num_prompts']
     opts.show_uptime = config['show_uptime']
-    opts.backend_url = config['backend_url'].strip('/')
+    opts.cluster = config['cluster']
     opts.show_total_output_tokens = config['show_total_output_tokens']
     opts.netdata_root = config['netdata_root']
     opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@@ -81,3 +76,15 @@ def load_config(config_path, script_path):
     redis.set('backend_mode', opts.mode)
 
     return success, config, msg
+
+
+def parse_backends(config):
+    if not config.get('cluster'):
+        return False
+    cluster = config.get('cluster')
+    config = {}
+    for item in cluster:
+        backend_url = item['backend_url'].strip('/')
+        item['backend_url'] = backend_url
+        config[backend_url] = item
+    return config
diff --git a/llm_server/config/redis_config.py b/llm_server/config/redis_config.py
new file mode 100644
index 0000000..06ab1d3
--- /dev/null
+++ b/llm_server/config/redis_config.py
@@ -0,0 +1,3 @@
+from llm_server.custom_redis import RedisCustom
+
+redis_config = RedisCustom('redis_config')
diff --git a/llm_server/routes/cache.py b/llm_server/custom_redis.py
similarity index 79%
rename from llm_server/routes/cache.py
rename to llm_server/custom_redis.py
index d7046db..1b0b5f3 100644
--- a/llm_server/routes/cache.py
+++ b/llm_server/custom_redis.py
@@ -1,19 +1,20 @@
+import pickle
 import sys
 import traceback
-from typing import Callable, List, Mapping, Union
+from typing import Callable, List, Mapping, Union, Optional
 
 import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
+from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT
 
 flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 
 ONE_MONTH_SECONDS = 2678000
 
 
-class RedisWrapper:
+class RedisCustom:
     """
     A wrapper class to set prefixes to keys.
""" @@ -40,7 +41,6 @@ class RedisWrapper: :param dtype: convert to this type :return: """ - d = self.redis.get(self._key(key)) if dtype and d: try: @@ -129,9 +129,35 @@ class RedisWrapper: ): return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt) + def hset( + self, + name: str, + key: Optional = None, + value=None, + mapping: Optional[dict] = None, + items: Optional[list] = None, + ): + return self.redis.hset(self._key(name), key, value, mapping, items) + def hkeys(self, name: str): return self.redis.hkeys(self._key(name)) + def hmget(self, name: str, keys: List, *args: List): + return self.redis.hmget(self._key(name), keys, *args) + + def hgetall(self, name: str): + return self.redis.hgetall(self._key(name)) + + def keys(self, pattern: PatternT = "*", **kwargs): + raw_keys = self.redis.keys(self._key(pattern), **kwargs) + keys = [] + for key in raw_keys: + p = key.decode('utf-8').split(':') + if len(p) > 2: + del p[0] + keys.append(':'.join(p)) + return keys + def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): return self.set(key, json.dumps(dict_value), ex=ex) @@ -142,6 +168,15 @@ class RedisWrapper: else: return json.loads(r.decode("utf-8")) + def setp(self, name, value): + self.redis.set(name, pickle.dumps(value)) + + def getp(self, name: str): + r = self.redis.get(name) + if r: + return pickle.load(r) + return r + def flush(self): flushed = [] for key in self.redis.scan_iter(f'{self.prefix}:*'): @@ -150,4 +185,4 @@ class RedisWrapper: return flushed -redis = RedisWrapper('local_llm') +redis = RedisCustom('local_llm') diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 9bfe578..3779c83 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -6,7 +6,7 @@ import llm_server from llm_server import opts from llm_server.database.conn import database from llm_server.llm.vllm import tokenize -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False): diff --git a/llm_server/helpers.py b/llm_server/helpers.py index 44b436b..d6eb7d9 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -8,7 +8,7 @@ import simplejson as json from flask import make_response from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def resolve_path(*p: str): diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index 742b1a5..a08b25e 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -1,5 +1,5 @@ from llm_server.llm import oobabooga, vllm -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def get_token_count(prompt: str): diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index 5a529ba..bedf3eb 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -3,20 +3,21 @@ import requests from llm_server import opts -def get_running_model(): - # TODO: cache the results for 1 min so we don't have to keep calling the backend - # TODO: only use one try/catch +def get_running_model(backend_url: str): + # TODO: remove this once we go to Redis + if not backend_url: + backend_url = opts.backend_url if opts.mode == 'oobabooga': try: - backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) 
+ backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['result'], None except Exception as e: return False, e elif opts.mode == 'vllm': try: - backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) + backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['model'], None except Exception as e: diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index 1c11c17..153f66d 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -4,7 +4,7 @@ import flask from llm_server import opts from llm_server.llm import get_token_count -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis class LLMBackend: diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py index 4336756..78f2190 100644 --- a/llm_server/llm/oobabooga/ooba_backend.py +++ b/llm_server/llm/oobabooga/ooba_backend.py @@ -3,7 +3,7 @@ from flask import jsonify from ..llm_backend import LLMBackend from ...database.database import log_prompt from ...helpers import safe_list_get -from ...routes.cache import redis +from llm_server.custom_redis import redis from ...routes.helpers.client import format_sillytavern_err from ...routes.helpers.http import validate_json diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index d5b64e3..8f1898e 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -12,7 +12,7 @@ from flask import jsonify, make_response import llm_server from llm_server import opts from llm_server.llm import get_token_count -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line. ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line. 
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 1549f2e..308b1de 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -9,7 +9,7 @@ import requests
 
 import llm_server
 from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 # TODO: make the VLMM backend return TPS and time elapsed
 
diff --git a/llm_server/opts.py b/llm_server/opts.py
index de23c7a..5eec1fa 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -37,3 +37,4 @@ openai_moderation_workers = 10
 openai_org_name = 'OpenAI'
 openai_silent_trim = False
 openai_moderation_enabled = True
+cluster = {}
diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py
index f3ea0f4..21da08e 100644
--- a/llm_server/pre_fork.py
+++ b/llm_server/pre_fork.py
@@ -2,7 +2,7 @@ import sys
 
 from redis import Redis
 
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.v1.generate_stats import generate_stats
 
 
diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py
index 48e721e..d97e9c5 100644
--- a/llm_server/routes/helpers/client.py
+++ b/llm_server/routes/helpers/client.py
@@ -1,5 +1,4 @@
-from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 
 def format_sillytavern_err(msg: str, level: str = 'info'):
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index cc27dce..a289c78 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -6,7 +6,7 @@ import traceback
 from flask import Response, jsonify, request
 
 from . import openai_bp
-from ..cache import redis
+from llm_server.custom_redis import redis
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ... import opts
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 503f628..05ac7a5 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -4,13 +4,12 @@ import traceback
 from flask import jsonify, make_response, request
 
 from . import openai_bp
-from ..cache import redis
-from ..helpers.client import format_sillytavern_err
+from llm_server.custom_redis import redis
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...llm import get_token_count
-from ...llm.openai.transform import build_openai_response, generate_oai_string
+from ...llm.openai.transform import generate_oai_string
 
 
 # TODO: add rate-limit headers?
diff --git a/llm_server/routes/openai/info.py b/llm_server/routes/openai/info.py
index 54959ae..4fc578a 100644
--- a/llm_server/routes/openai/info.py
+++ b/llm_server/routes/openai/info.py
@@ -1,7 +1,7 @@
 from flask import Response
 
 from . import openai_bp
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
 from ... import opts
 
 
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index 47223e7..4f732e6 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -4,7 +4,7 @@ import requests
 from flask import jsonify
 
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
+from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...helpers import jsonify_pretty
diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py
index f626490..301e8de 100644
--- a/llm_server/routes/openai/simulated.py
+++ b/llm_server/routes/openai/simulated.py
@@ -1,7 +1,7 @@
 from flask import jsonify
 
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, flask_cache
+from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
 from ...llm.openai.transform import generate_oai_string
 from ..stats import server_start_time
 
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index 84cc614..8d85319 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -6,7 +6,7 @@ from uuid import uuid4
 from redis import Redis
 
 from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 
 def increment_ip_count(client_ip: str, redis_key):
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 4b1f640..bb64859 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -11,7 +11,7 @@ from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue
 
diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py
index a6e9e17..a0846d9 100644
--- a/llm_server/routes/stats.py
+++ b/llm_server/routes/stats.py
@@ -1,6 +1,6 @@
 from datetime import datetime
 
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 # proompters_5_min = 0
 # concurrent_semaphore = Semaphore(concurrent_gens)
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index e144099..b2dd527 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -6,7 +6,7 @@ from llm_server.database.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
 
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 0fc8f40..45fbf12 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -6,7 +6,6 @@ from typing import Union
 
 from flask import request
 
-from ..cache import redis
 from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 2091118..7cdbf0f 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -4,7 +4,7 @@ from flask import jsonify, request
 
 from . import bp
 from ..auth import requires_auth
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
 from ... import opts
 from ...llm.info import get_running_model
diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py
index 4349526..5ffd194 100644
--- a/llm_server/routes/v1/proxy.py
+++ b/llm_server/routes/v1/proxy.py
@@ -1,8 +1,6 @@
-from flask import jsonify
-
 from . import bp
 from .generate_stats import generate_stats
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
 from ...helpers import jsonify_pretty
 
 
diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py
index 27b0815..dcf0047 100644
--- a/llm_server/workers/blocking.py
+++ b/llm_server/workers/blocking.py
@@ -3,7 +3,7 @@ import time
 
 from llm_server import opts
 from llm_server.llm.generator import generator
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
 
 
diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py
index 747f699..f592c5e 100644
--- a/llm_server/workers/main.py
+++ b/llm_server/workers/main.py
@@ -1,10 +1,9 @@
 import time
-from threading import Thread
 
 from llm_server import opts
 from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 
 def main_background_thread():
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index cb0f032..6a33835 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -1,7 +1,7 @@
 import logging
 import time
 
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.queue import priority_queue
 
 logger = logging.getLogger('console_printer')
diff --git a/llm_server/workers/recent.py b/llm_server/workers/recent.py
index d378a87..c6158d6 100644
--- a/llm_server/workers/recent.py
+++ b/llm_server/workers/recent.py
@@ -1,6 +1,6 @@
 import time
 
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 
 def recent_prompters_thread():
diff --git a/requirements.txt b/requirements.txt
index 9b0c8eb..7b49eed 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,4 +18,3 @@ websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
 urllib3~=2.0.4
-celery[redis]
diff --git a/server.py b/server.py
index 06482d4..3c334bc 100644
--- a/server.py
+++ b/server.py
@@ -8,7 +8,7 @@ except ImportError:
     pass
 
 from llm_server.pre_fork import server_startup
-from llm_server.config.load import load_config
+from llm_server.config.load import load_config, parse_backends
 import os
 import sys
 from pathlib import Path
@@ -36,6 +36,7 @@ from llm_server.stream import init_socketio
 # TODO: add a way to cancel VLLM gens. Maybe use websockets?
 # TODO: use coloredlogs
 # TODO: need to update opts. for workers
+# TODO: add a healthcheck to VLLM
 
 # Lower priority
 # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
@@ -63,7 +64,7 @@ import config
 from llm_server import opts
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.routes.cache import RedisWrapper, flask_cache
+from llm_server.custom_redis import RedisCustom, flask_cache
 from llm_server.llm import redis
 from llm_server.routes.stats import get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
@@ -89,9 +90,11 @@ if not success:
 database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()
 
-llm_server.llm.redis = RedisWrapper('local_llm')
+llm_server.llm.redis = RedisCustom('local_llm')
 create_db()
 
+x = parse_backends(config)
+print(x)
 
 
 # print(app.url_map)
diff --git a/test-cluster.py b/test-cluster.py
new file mode 100644
index 0000000..531892b
--- /dev/null
+++ b/test-cluster.py
@@ -0,0 +1,29 @@
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import time
+from threading import Thread
+
+from llm_server.cluster.funcs.backend import get_best_backends
+from llm_server.cluster.redis_config_cache import RedisClusterStore
+from llm_server.cluster.worker import cluster_worker
+from llm_server.config.load import parse_backends, load_config
+
+success, config, msg = load_config('./config/config.yml')
+
+cluster_config = RedisClusterStore('cluster_config')
+cluster_config.clear()
+cluster_config.load(parse_backends(config))
+
+t = Thread(target=cluster_worker)
+t.daemon = True
+t.start()
+
+while True:
+    x = get_best_backends()
+    print(x)
+    time.sleep(3)