set up cluster config and basic background workers

Cyberes 2023-09-28 18:40:24 -06:00
parent 89e9f42663
commit e7b57cad7b
40 changed files with 219 additions and 54 deletions

View File

@@ -1,6 +1,6 @@
 import time
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 try:
     import gevent.monkey

View File

@@ -0,0 +1,26 @@
+from llm_server.cluster.redis_config_cache import RedisClusterStore
+from llm_server.llm.info import get_running_model
+
+
+def test_backend(backend_url: str):
+    running_model, err = get_running_model(backend_url)
+    if not running_model:
+        return False
+    return True
+
+
+def get_best_backends():
+    cluster_config = RedisClusterStore('cluster_config')
+    backends = cluster_config.all()
+    result = {}
+    for k, v in backends.items():
+        b = cluster_config.get_backend(k)
+        status = b['online']
+        priority = b['priority']
+        result[k] = {'status': status, 'priority': priority}
+    online_backends = sorted(
+        ((url, info) for url, info in backends.items() if info['online']),
+        key=lambda kv: kv[1]['priority'],
+        reverse=True
+    )
+    return [url for url, info in online_backends]
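A minimal sketch of how this selection helper might be polled once the cluster store has been populated (the URLs in the comment are placeholders):

    from llm_server.cluster.funcs.backend import get_best_backends

    # Online backends come back sorted by 'priority', highest value first,
    # e.g. ['http://10.0.0.50:7000', 'http://10.0.0.51:7000']
    best = get_best_backends()
    print(best)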

View File

@@ -0,0 +1,42 @@
+import pickle
+
+from llm_server.custom_redis import RedisCustom
+
+
+class RedisClusterStore:
+    def __init__(self, name: str, **kwargs):
+        self.name = name
+        self.config_redis = RedisCustom(name, **kwargs)
+
+    def clear(self):
+        self.config_redis.flush()
+
+    def load(self, config: dict):
+        for k, v in config.items():
+            self.set_backend(k, v)
+
+    def set_backend(self, name: str, values: dict):
+        self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
+        self.set_backend_value(name, 'online', False)
+
+    def set_backend_value(self, key: str, name: str, value):
+        self.config_redis.hset(key, name, pickle.dumps(value))
+
+    def get_backend(self, name: str):
+        r = self.config_redis.hgetall(name)
+        output = {}
+        for k, v in r.items():
+            output[k.decode('utf8')] = pickle.loads(v)
+        return output
+
+    def all(self):
+        keys = self.config_redis.keys('*')
+        if keys:
+            result = {}
+            for key in keys:
+                if key != f'{self.name}:____':
+                    v = self.get_backend(key)
+                    result[key] = v
+            return result
+        else:
+            return {}
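A minimal usage sketch of the store above, assuming a Redis server on localhost and a placeholder backend URL (parse_backends() from this commit produces dicts of this shape):

    from llm_server.cluster.redis_config_cache import RedisClusterStore

    cluster_config = RedisClusterStore('cluster_config')
    cluster_config.clear()  # flush any stale keys under this prefix
    cluster_config.load({
        'http://10.0.0.50:7000': {'backend_url': 'http://10.0.0.50:7000', 'priority': 10},
    })

    # Each backend is a Redis hash of pickled values; 'online' is forced to
    # False on load until the background worker flips it.
    print(cluster_config.get_backend('http://10.0.0.50:7000'))
    # {'backend_url': 'http://10.0.0.50:7000', 'priority': 10, 'online': False}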

View File

@@ -0,0 +1,25 @@
+import time
+from threading import Thread
+
+from llm_server.cluster.funcs.backend import test_backend
+from llm_server.cluster.redis_config_cache import RedisClusterStore
+
+cluster_config = RedisClusterStore('cluster_config')
+
+
+def cluster_worker():
+    while True:
+        threads = []
+        for n, v in cluster_config.all().items():
+            thread = Thread(target=check_backend, args=(n, v))
+            thread.start()
+            threads.append(thread)
+        for thread in threads:
+            thread.join()
+        time.sleep(10)
+
+
+def check_backend(n, v):
+    # Check if backends are online
+    online = test_backend(v['backend_url'])
+    cluster_config.set_backend_value(n, 'online', online)
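The worker loop is meant to run in a daemon thread, as the test-cluster.py script added at the end of this commit does:

    from threading import Thread
    from llm_server.cluster.worker import cluster_worker

    t = Thread(target=cluster_worker)
    t.daemon = True
    t.start()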

View File

@@ -5,22 +5,17 @@ import openai
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.custom_redis import redis
 from llm_server.database.conn import database
 from llm_server.database.database import get_number_of_rows
-from llm_server.helpers import resolve_path
-from llm_server.routes.cache import redis
-def load_config(config_path, script_path):
+def load_config(config_path):
     config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
     success, config, msg = config_loader.load_config()
     if not success:
         return success, config, msg
-    # Resolve relative directory to the directory of the script
-    if config['database_path'].startswith('./'):
-        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
     if config['mode'] not in ['oobabooga', 'vllm']:
         print('Unknown mode:', config['mode'])
         sys.exit(1)
@@ -34,7 +29,7 @@ def load_config(config_path, script_path):
     opts.context_size = config['token_limit']
     opts.show_num_prompts = config['show_num_prompts']
     opts.show_uptime = config['show_uptime']
-    opts.backend_url = config['backend_url'].strip('/')
+    opts.cluster = config['cluster']
     opts.show_total_output_tokens = config['show_total_output_tokens']
     opts.netdata_root = config['netdata_root']
     opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@@ -81,3 +76,15 @@ def load_config(config_path, script_path):
     redis.set('backend_mode', opts.mode)
     return success, config, msg
+
+
+def parse_backends(config):
+    if not config.get('cluster'):
+        return False
+    cluster = config.get('cluster')
+    config = {}
+    for item in cluster:
+        backend_url = item['backend_url'].strip('/')
+        item['backend_url'] = backend_url
+        config[backend_url] = item
+    return config
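For reference, a rough sketch of the shape parse_backends() consumes and returns; the per-backend fields beyond backend_url are assumptions here (priority is the field get_best_backends() reads later):

    # Hypothetical 'cluster' section after the YAML config has been parsed.
    config = {
        'cluster': [
            {'backend_url': 'http://10.0.0.50:7000/', 'priority': 10},  # trailing slash is stripped
            {'backend_url': 'http://10.0.0.51:7000', 'priority': 5},
        ]
    }

    backends = parse_backends(config)
    # {'http://10.0.0.50:7000': {'backend_url': 'http://10.0.0.50:7000', 'priority': 10},
    #  'http://10.0.0.51:7000': {'backend_url': 'http://10.0.0.51:7000', 'priority': 5}}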

View File

@@ -0,0 +1,3 @@
+from llm_server.custom_redis import RedisCustom
+
+redis_config = RedisCustom('redis_config')

View File

@@ -1,19 +1,20 @@
+import pickle
 import sys
 import traceback
-from typing import Callable, List, Mapping, Union
+from typing import Callable, List, Mapping, Union, Optional
 import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
+from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT
 flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 ONE_MONTH_SECONDS = 2678000
-class RedisWrapper:
+class RedisCustom:
     """
     A wrapper class to set prefixes to keys.
     """
@@ -40,7 +41,6 @@ class RedisWrapper:
         :param dtype: convert to this type
         :return:
         """
-
         d = self.redis.get(self._key(key))
         if dtype and d:
             try:
@@ -129,9 +129,35 @@ class RedisWrapper:
     ):
         return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
+    def hset(
+            self,
+            name: str,
+            key: Optional = None,
+            value=None,
+            mapping: Optional[dict] = None,
+            items: Optional[list] = None,
+    ):
+        return self.redis.hset(self._key(name), key, value, mapping, items)
     def hkeys(self, name: str):
         return self.redis.hkeys(self._key(name))
+    def hmget(self, name: str, keys: List, *args: List):
+        return self.redis.hmget(self._key(name), keys, *args)
+    def hgetall(self, name: str):
+        return self.redis.hgetall(self._key(name))
+    def keys(self, pattern: PatternT = "*", **kwargs):
+        raw_keys = self.redis.keys(self._key(pattern), **kwargs)
+        keys = []
+        for key in raw_keys:
+            p = key.decode('utf-8').split(':')
+            if len(p) > 2:
+                del p[0]
+            keys.append(':'.join(p))
+        return keys
     def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
         return self.set(key, json.dumps(dict_value), ex=ex)
@@ -142,6 +168,15 @@ class RedisWrapper:
         else:
             return json.loads(r.decode("utf-8"))
+    def setp(self, name, value):
+        self.redis.set(name, pickle.dumps(value))
+    def getp(self, name: str):
+        r = self.redis.get(name)
+        if r:
+            return pickle.loads(r)
+        return r
     def flush(self):
         flushed = []
         for key in self.redis.scan_iter(f'{self.prefix}:*'):
@@ -150,4 +185,4 @@ class RedisWrapper:
         return flushed
-redis = RedisWrapper('local_llm')
+redis = RedisCustom('local_llm')
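A short sketch of the new RedisCustom helpers in use, assuming a Redis server on localhost (the key names are made up for illustration). Note that setp()/getp() as written operate on the raw key name rather than the prefixed one:

    from llm_server.custom_redis import RedisCustom

    r = RedisCustom('local_llm')
    r.hset('backend:1', mapping={'status': 'ok'})  # stored under 'local_llm:backend:1'
    print(r.hgetall('backend:1'))                  # {b'status': b'ok'}
    print(r.keys('backend:*'))                     # prefix stripped back off: ['backend:1']

    r.setp('some_object', {'a': 1})                # pickled, unprefixed key
    print(r.getp('some_object'))                   # {'a': 1}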

View File

@@ -6,7 +6,7 @@ import llm_server
 from llm_server import opts
 from llm_server.database.conn import database
 from llm_server.llm.vllm import tokenize
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):

View File

@@ -8,7 +8,7 @@ import simplejson as json
 from flask import make_response
 from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def resolve_path(*p: str):

View File

@@ -1,5 +1,5 @@
 from llm_server.llm import oobabooga, vllm
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def get_token_count(prompt: str):

View File

@@ -3,20 +3,21 @@ import requests
 from llm_server import opts
-def get_running_model():
-    # TODO: cache the results for 1 min so we don't have to keep calling the backend
-    # TODO: only use one try/catch
+def get_running_model(backend_url: str):
+    # TODO: remove this once we go to Redis
+    if not backend_url:
+        backend_url = opts.backend_url
     if opts.mode == 'oobabooga':
         try:
-            backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
             r_json = backend_response.json()
             return r_json['result'], None
         except Exception as e:
             return False, e
     elif opts.mode == 'vllm':
         try:
-            backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
             r_json = backend_response.json()
             return r_json['model'], None
         except Exception as e:
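Callers added in this commit now pass the backend URL explicitly, as test_backend() does; a minimal sketch with a placeholder URL:

    running_model, err = get_running_model('http://10.0.0.50:7000')  # placeholder URL
    if not running_model:
        print('backend appears offline:', err)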

View File

@@ -4,7 +4,7 @@ import flask
 from llm_server import opts
 from llm_server.llm import get_token_count
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 class LLMBackend:

View File

@@ -3,7 +3,7 @@ from flask import jsonify
 from ..llm_backend import LLMBackend
 from ...database.database import log_prompt
 from ...helpers import safe_list_get
-from ...routes.cache import redis
+from llm_server.custom_redis import redis
 from ...routes.helpers.client import format_sillytavern_err
 from ...routes.helpers.http import validate_json

View File

@@ -12,7 +12,7 @@ from flask import jsonify, make_response
 import llm_server
 from llm_server import opts
 from llm_server.llm import get_token_count
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
 ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.

View File

@@ -9,7 +9,7 @@ import requests
 import llm_server
 from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 # TODO: make the VLMM backend return TPS and time elapsed

View File

@@ -37,3 +37,4 @@ openai_moderation_workers = 10
 openai_org_name = 'OpenAI'
 openai_silent_trim = False
 openai_moderation_enabled = True
+cluster = {}

View File

@@ -2,7 +2,7 @@ import sys
 from redis import Redis
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.v1.generate_stats import generate_stats

View File

@@ -1,5 +1,4 @@
-from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def format_sillytavern_err(msg: str, level: str = 'info'):

View File

@@ -6,7 +6,7 @@ import traceback
 from flask import Response, jsonify, request
 from . import openai_bp
-from ..cache import redis
+from llm_server.custom_redis import redis
 from ..helpers.http import validate_json
 from ..openai_request_handler import OpenAIRequestHandler
 from ... import opts

View File

@@ -4,13 +4,12 @@ import traceback
 from flask import jsonify, make_response, request
 from . import openai_bp
-from ..cache import redis
+from llm_server.custom_redis import redis
-from ..helpers.client import format_sillytavern_err
 from ..helpers.http import validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ... import opts
 from ...llm import get_token_count
-from ...llm.openai.transform import build_openai_response, generate_oai_string
+from ...llm.openai.transform import generate_oai_string
 # TODO: add rate-limit headers?

View File

@@ -1,7 +1,7 @@
 from flask import Response
 from . import openai_bp
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
 from ... import opts

View File

@@ -4,7 +4,7 @@ import requests
 from flask import jsonify
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
+from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
 from ..stats import server_start_time
 from ... import opts
 from ...helpers import jsonify_pretty

View File

@@ -1,7 +1,7 @@
 from flask import jsonify
 from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, flask_cache
+from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
 from ...llm.openai.transform import generate_oai_string
 from ..stats import server_start_time

View File

@@ -6,7 +6,7 @@ from uuid import uuid4
 from redis import Redis
 from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def increment_ip_count(client_ip: str, redis_key):

View File

@@ -11,7 +11,7 @@ from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
 from llm_server.llm.vllm.vllm_backend import VLLMBackend
 from llm_server.routes.auth import parse_token
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.helpers.http import require_api_key, validate_json
 from llm_server.routes.queue import priority_queue

View File

@@ -1,6 +1,6 @@
 from datetime import datetime
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 # proompters_5_min = 0
 # concurrent_semaphore = Semaphore(concurrent_gens)

View File

@@ -6,7 +6,7 @@ from llm_server.database.database import get_distinct_ips_24h, sum_column
 from llm_server.helpers import deep_sort, round_up_base
 from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.queue import priority_queue
 from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time

View File

@@ -6,7 +6,6 @@ from typing import Union
 from flask import request
-from ..cache import redis
 from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import decr_active_workers, decrement_ip_count, priority_queue

View File

@@ -4,7 +4,7 @@ from flask import jsonify, request
 from . import bp
 from ..auth import requires_auth
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
 from ... import opts
 from ...llm.info import get_running_model

View File

@@ -1,8 +1,6 @@
-from flask import jsonify
 from . import bp
 from .generate_stats import generate_stats
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
 from ...helpers import jsonify_pretty

View File

@@ -3,7 +3,7 @@ import time
 from llm_server import opts
 from llm_server.llm.generator import generator
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue

View File

@@ -1,10 +1,9 @@
 import time
-from threading import Thread
 from llm_server import opts
 from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_running_model
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def main_background_thread():

View File

@@ -1,7 +1,7 @@
 import logging
 import time
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 from llm_server.routes.queue import priority_queue
 logger = logging.getLogger('console_printer')

View File

@@ -1,6 +1,6 @@
 import time
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 def recent_prompters_thread():

View File

@@ -18,4 +18,3 @@ websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
 urllib3~=2.0.4
-celery[redis]

View File

@@ -8,7 +8,7 @@ except ImportError:
     pass
 from llm_server.pre_fork import server_startup
-from llm_server.config.load import load_config
+from llm_server.config.load import load_config, parse_backends
 import os
 import sys
 from pathlib import Path
@@ -36,6 +36,7 @@ from llm_server.stream import init_socketio
 # TODO: add a way to cancel VLLM gens. Maybe use websockets?
 # TODO: use coloredlogs
 # TODO: need to update opts. for workers
+# TODO: add a healthcheck to VLLM
 # Lower priority
 # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
@@ -63,7 +64,7 @@ import config
 from llm_server import opts
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.routes.cache import RedisWrapper, flask_cache
+from llm_server.custom_redis import RedisCustom, flask_cache
 from llm_server.llm import redis
 from llm_server.routes.stats import get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
@@ -89,9 +90,11 @@ if not success:
 database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()
-llm_server.llm.redis = RedisWrapper('local_llm')
+llm_server.llm.redis = RedisCustom('local_llm')
 create_db()
+x = parse_backends(config)
+print(x)
 # print(app.url_map)

test-cluster.py Normal file
View File

@@ -0,0 +1,29 @@
+try:
+    import gevent.monkey
+
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
+import time
+from threading import Thread
+
+from llm_server.cluster.funcs.backend import get_best_backends
+from llm_server.cluster.redis_config_cache import RedisClusterStore
+from llm_server.cluster.worker import cluster_worker
+from llm_server.config.load import parse_backends, load_config
+
+success, config, msg = load_config('./config/config.yml')
+
+cluster_config = RedisClusterStore('cluster_config')
+cluster_config.clear()
+cluster_config.load(parse_backends(config))
+
+t = Thread(target=cluster_worker)
+t.daemon = True
+t.start()
+
+while True:
+    x = get_best_backends()
+    print(x)
+    time.sleep(3)