set up cluster config and basic background workers

This commit is contained in:
Cyberes 2023-09-28 18:40:24 -06:00
parent 89e9f42663
commit e7b57cad7b
40 changed files with 219 additions and 54 deletions

View File

@ -1,6 +1,6 @@
import time
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
try:
import gevent.monkey

View File

View File

View File

View File

@ -0,0 +1,26 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.llm.info import get_running_model
def test_backend(backend_url: str):
running_model, err = get_running_model(backend_url)
if not running_model:
return False
return True
def get_best_backends():
cluster_config = RedisClusterStore('cluster_config')
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b['online']
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends]

View File

@ -0,0 +1,42 @@
import pickle
from llm_server.custom_redis import RedisCustom
class RedisClusterStore:
def __init__(self, name: str, **kwargs):
self.name = name
self.config_redis = RedisCustom(name, **kwargs)
def clear(self):
self.config_redis.flush()
def load(self, config: dict):
for k, v in config.items():
self.set_backend(k, v)
def set_backend(self, name: str, values: dict):
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
self.set_backend_value(name, 'online', False)
def set_backend_value(self, key: str, name: str, value):
self.config_redis.hset(key, name, pickle.dumps(value))
def get_backend(self, name: str):
r = self.config_redis.hgetall(name)
output = {}
for k, v in r.items():
output[k.decode('utf8')] = pickle.loads(v)
return output
def all(self):
keys = self.config_redis.keys('*')
if keys:
result = {}
for key in keys:
if key != f'{self.name}:____':
v = self.get_backend(key)
result[key] = v
return result
else:
return {}

View File

@ -0,0 +1,25 @@
import time
from threading import Thread
from llm_server.cluster.funcs.backend import test_backend
from llm_server.cluster.redis_config_cache import RedisClusterStore
cluster_config = RedisClusterStore('cluster_config')
def cluster_worker():
while True:
threads = []
for n, v in cluster_config.all().items():
thread = Thread(target=check_backend, args=(n, v))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
time.sleep(10)
def check_backend(n, v):
# Check if backends are online
online = test_backend(v['backend_url'])
cluster_config.set_backend_value(n, 'online', online)

View File

@ -5,22 +5,17 @@ import openai
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.routes.cache import redis
def load_config(config_path, script_path):
def load_config(config_path):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
return success, config, msg
# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
@ -34,7 +29,7 @@ def load_config(config_path, script_path):
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.cluster = config['cluster']
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@ -81,3 +76,15 @@ def load_config(config_path, script_path):
redis.set('backend_mode', opts.mode)
return success, config, msg
def parse_backends(config):
if not config.get('cluster'):
return False
cluster = config.get('cluster')
config = {}
for item in cluster:
backend_url = item['backend_url'].strip('/')
item['backend_url'] = backend_url
config[backend_url] = item
return config

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_config = RedisCustom('redis_config')

View File

@ -1,19 +1,20 @@
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Union
from typing import Callable, List, Mapping, Union, Optional
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
ONE_MONTH_SECONDS = 2678000
class RedisWrapper:
class RedisCustom:
"""
A wrapper class to set prefixes to keys.
"""
@ -40,7 +41,6 @@ class RedisWrapper:
:param dtype: convert to this type
:return:
"""
d = self.redis.get(self._key(key))
if dtype and d:
try:
@ -129,9 +129,35 @@ class RedisWrapper:
):
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
def hset(
self,
name: str,
key: Optional = None,
value=None,
mapping: Optional[dict] = None,
items: Optional[list] = None,
):
return self.redis.hset(self._key(name), key, value, mapping, items)
def hkeys(self, name: str):
return self.redis.hkeys(self._key(name))
def hmget(self, name: str, keys: List, *args: List):
return self.redis.hmget(self._key(name), keys, *args)
def hgetall(self, name: str):
return self.redis.hgetall(self._key(name))
def keys(self, pattern: PatternT = "*", **kwargs):
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
keys = []
for key in raw_keys:
p = key.decode('utf-8').split(':')
if len(p) > 2:
del p[0]
keys.append(':'.join(p))
return keys
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex)
@ -142,6 +168,15 @@ class RedisWrapper:
else:
return json.loads(r.decode("utf-8"))
def setp(self, name, value):
self.redis.set(name, pickle.dumps(value))
def getp(self, name: str):
r = self.redis.get(name)
if r:
return pickle.load(r)
return r
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):
@ -150,4 +185,4 @@ class RedisWrapper:
return flushed
redis = RedisWrapper('local_llm')
redis = RedisCustom('local_llm')

View File

@ -6,7 +6,7 @@ import llm_server
from llm_server import opts
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):

View File

@ -8,7 +8,7 @@ import simplejson as json
from flask import make_response
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def resolve_path(*p: str):

View File

@ -1,5 +1,5 @@
from llm_server.llm import oobabooga, vllm
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def get_token_count(prompt: str):

View File

@ -3,20 +3,21 @@ import requests
from llm_server import opts
def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
# TODO: only use one try/catch
def get_running_model(backend_url: str):
# TODO: remove this once we go to Redis
if not backend_url:
backend_url = opts.backend_url
if opts.mode == 'oobabooga':
try:
backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['result'], None
except Exception as e:
return False, e
elif opts.mode == 'vllm':
try:
backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['model'], None
except Exception as e:

View File

@ -4,7 +4,7 @@ import flask
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
class LLMBackend:

View File

@ -3,7 +3,7 @@ from flask import jsonify
from ..llm_backend import LLMBackend
from ...database.database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from llm_server.custom_redis import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json

View File

@ -12,7 +12,7 @@ from flask import jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.

View File

@ -9,7 +9,7 @@ import requests
import llm_server
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
# TODO: make the VLMM backend return TPS and time elapsed

View File

@ -37,3 +37,4 @@ openai_moderation_workers = 10
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}

View File

@ -2,7 +2,7 @@ import sys
from redis import Redis
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
from llm_server.routes.v1.generate_stats import generate_stats

View File

@ -1,5 +1,4 @@
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, level: str = 'info'):

View File

@ -6,7 +6,7 @@ import traceback
from flask import Response, jsonify, request
from . import openai_bp
from ..cache import redis
from llm_server.custom_redis import redis
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts

View File

@ -4,13 +4,12 @@ import traceback
from flask import jsonify, make_response, request
from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from llm_server.custom_redis import redis
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ... import opts
from ...llm import get_token_count
from ...llm.openai.transform import build_openai_response, generate_oai_string
from ...llm.openai.transform import generate_oai_string
# TODO: add rate-limit headers?

View File

@ -1,7 +1,7 @@
from flask import Response
from . import openai_bp
from ..cache import flask_cache
from llm_server.custom_redis import flask_cache
from ... import opts

View File

@ -4,7 +4,7 @@ import requests
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...helpers import jsonify_pretty

View File

@ -1,7 +1,7 @@
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time

View File

@ -6,7 +6,7 @@ from uuid import uuid4
from redis import Redis
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def increment_ip_count(client_ip: str, redis_key):

View File

@ -11,7 +11,7 @@ from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue

View File

@ -1,6 +1,6 @@
from datetime import datetime
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
# proompters_5_min = 0
# concurrent_semaphore = Semaphore(concurrent_gens)

View File

@ -6,7 +6,7 @@ from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time

View File

@ -6,7 +6,6 @@ from typing import Union
from flask import request
from ..cache import redis
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue

View File

@ -4,7 +4,7 @@ from flask import jsonify, request
from . import bp
from ..auth import requires_auth
from ..cache import flask_cache
from llm_server.custom_redis import flask_cache
from ... import opts
from ...llm.info import get_running_model

View File

@ -1,8 +1,6 @@
from flask import jsonify
from . import bp
from .generate_stats import generate_stats
from ..cache import flask_cache
from llm_server.custom_redis import flask_cache
from ...helpers import jsonify_pretty

View File

@ -3,7 +3,7 @@ import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue

View File

@ -1,10 +1,9 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def main_background_thread():

View File

@ -1,7 +1,7 @@
import logging
import time
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
logger = logging.getLogger('console_printer')

View File

@ -1,6 +1,6 @@
import time
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def recent_prompters_thread():

View File

@ -18,4 +18,3 @@ websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
celery[redis]

View File

@ -8,7 +8,7 @@ except ImportError:
pass
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config
from llm_server.config.load import load_config, parse_backends
import os
import sys
from pathlib import Path
@ -36,6 +36,7 @@ from llm_server.stream import init_socketio
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: use coloredlogs
# TODO: need to update opts. for workers
# TODO: add a healthcheck to VLLM
# Lower priority
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
@ -63,7 +64,7 @@ import config
from llm_server import opts
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.custom_redis import RedisCustom, flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
@ -89,9 +90,11 @@ if not success:
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
llm_server.llm.redis = RedisWrapper('local_llm')
llm_server.llm.redis = RedisCustom('local_llm')
create_db()
x = parse_backends(config)
print(x)
# print(app.url_map)

29
test-cluster.py Normal file
View File

@ -0,0 +1,29 @@
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import time
from threading import Thread
from llm_server.cluster.funcs.backend import get_best_backends
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.cluster.worker import cluster_worker
from llm_server.config.load import parse_backends, load_config
success, config, msg = load_config('./config/config.yml').resolve().absolute()
cluster_config = RedisClusterStore('cluster_config')
cluster_config.clear()
cluster_config.load(parse_backends(config))
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
while True:
x = get_best_backends()
print(x)
time.sleep(3)