From 624ca74ce5f31bc81292c6507eb7830081d3e625 Mon Sep 17 00:00:00 2001 From: Cyberes Date: Fri, 29 Sep 2023 00:09:44 -0600 Subject: [PATCH] mvp --- daemon.py | 18 +--- llm_server/cluster/backend.py | 71 ++++++++++++ llm_server/cluster/cluster_config.py | 3 + llm_server/cluster/datastore.py | 0 llm_server/cluster/funcs/__init__.py | 0 llm_server/cluster/funcs/backend.py | 26 ----- llm_server/cluster/redis_config_cache.py | 12 ++- llm_server/cluster/redis_cycle.py | 21 ++++ llm_server/cluster/stores.py | 3 + llm_server/cluster/worker.py | 20 ++-- llm_server/config/config.py | 3 +- llm_server/config/load.py | 1 + llm_server/custom_redis.py | 34 ++++-- llm_server/database/database.py | 91 +++++++++------- llm_server/helpers.py | 2 +- llm_server/integer.py | 12 --- llm_server/llm/__init__.py | 2 +- llm_server/llm/generator.py | 4 +- llm_server/llm/info.py | 10 +- llm_server/llm/llm_backend.py | 2 +- llm_server/llm/openai/transform.py | 4 +- llm_server/llm/vllm/generate.py | 12 +-- llm_server/llm/vllm/vllm_backend.py | 12 +-- llm_server/opts.py | 2 +- llm_server/pre_fork.py | 2 +- llm_server/routes/helpers/client.py | 8 +- llm_server/routes/ooba_request_handler.py | 4 +- llm_server/routes/openai/chat_completions.py | 10 +- llm_server/routes/openai/completions.py | 4 +- llm_server/routes/openai/models.py | 3 +- llm_server/routes/queue.py | 2 +- llm_server/routes/request_handler.py | 22 ++-- llm_server/routes/stats.py | 23 ---- llm_server/routes/v1/generate_stats.py | 101 +++++++++++------- llm_server/routes/v1/generate_stream.py | 25 ++--- llm_server/routes/v1/info.py | 40 +++---- llm_server/{stream.py => sock.py} | 0 llm_server/workers/app.py | 35 ------ .../workers/{blocking.py => inferencer.py} | 8 +- llm_server/workers/main.py | 55 ---------- llm_server/workers/mainer.py | 56 ++++++++++ llm_server/workers/{recent.py => recenter.py} | 0 llm_server/workers/threader.py | 50 +++++++++ llm_server/workers/threads.py | 9 -- gunicorn.py => other/gunicorn.py | 5 + server.py | 49 +++++---- test-cluster.py | 20 +++- 47 files changed, 506 insertions(+), 390 deletions(-) create mode 100644 llm_server/cluster/backend.py create mode 100644 llm_server/cluster/cluster_config.py delete mode 100644 llm_server/cluster/datastore.py delete mode 100644 llm_server/cluster/funcs/__init__.py delete mode 100644 llm_server/cluster/funcs/backend.py create mode 100644 llm_server/cluster/redis_cycle.py create mode 100644 llm_server/cluster/stores.py delete mode 100644 llm_server/integer.py rename llm_server/{stream.py => sock.py} (100%) delete mode 100644 llm_server/workers/app.py rename llm_server/workers/{blocking.py => inferencer.py} (88%) delete mode 100644 llm_server/workers/main.py create mode 100644 llm_server/workers/mainer.py rename llm_server/workers/{recent.py => recenter.py} (100%) create mode 100644 llm_server/workers/threader.py delete mode 100644 llm_server/workers/threads.py rename gunicorn.py => other/gunicorn.py (60%) diff --git a/daemon.py b/daemon.py index 93e8d34..82635f0 100644 --- a/daemon.py +++ b/daemon.py @@ -1,22 +1,12 @@ -import time - -from llm_server.custom_redis import redis - -try: - import gevent.monkey - - gevent.monkey.patch_all() -except ImportError: - pass - import os import sys +import time from pathlib import Path from llm_server.config.load import load_config +from llm_server.custom_redis import redis from llm_server.database.create import create_db - -from llm_server.workers.app import start_background +from llm_server.workers.threader import start_background script_path = 
os.path.dirname(os.path.realpath(__file__)) config_path_environ = os.getenv("CONFIG_PATH") @@ -29,7 +19,7 @@ if __name__ == "__main__": flushed_keys = redis.flush() print('Flushed', len(flushed_keys), 'keys from Redis.') - success, config, msg = load_config(config_path, script_path) + success, config, msg = load_config(config_path) if not success: print('Failed to load config:', msg) sys.exit(1) diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py new file mode 100644 index 0000000..7b28e86 --- /dev/null +++ b/llm_server/cluster/backend.py @@ -0,0 +1,71 @@ +from llm_server.cluster.redis_config_cache import RedisClusterStore +from llm_server.cluster.redis_cycle import redis_cycle +from llm_server.cluster.stores import redis_running_models +from llm_server.llm.info import get_running_model + + +def test_backend(backend_url: str, mode: str): + running_model, err = get_running_model(backend_url, mode) + if not running_model: + return False + return True + + +def get_backends(): + cluster_config = RedisClusterStore('cluster_config') + backends = cluster_config.all() + result = {} + for k, v in backends.items(): + b = cluster_config.get_backend(k) + status = b['online'] + priority = b['priority'] + result[k] = {'status': status, 'priority': priority} + online_backends = sorted( + ((url, info) for url, info in backends.items() if info['online']), + key=lambda kv: -kv[1]['priority'], + reverse=True + ) + offline_backends = sorted( + ((url, info) for url, info in backends.items() if not info['online']), + key=lambda kv: -kv[1]['priority'], + reverse=True + ) + return [url for url, info in online_backends], [url for url, info in offline_backends] + + +def get_a_cluster_backend(): + """ + Get a backend from Redis. If there are no online backends, return None. 
+ """ + online, offline = get_backends() + cycled = redis_cycle('backend_cycler') + c = cycled.copy() + for i in range(len(cycled)): + if cycled[i] in offline: + del c[c.index(cycled[i])] + if len(c): + return c[0] + else: + return None + + +def get_backends_from_model(model_name: str): + cluster_config = RedisClusterStore('cluster_config') + a = cluster_config.all() + matches = [] + for k, v in a.items(): + if v['online'] and v['running_model'] == model_name: + matches.append(k) + return matches + + +def purge_backend_from_running_models(backend_url: str): + keys = redis_running_models.keys() + pipeline = redis_running_models.pipeline() + for model in keys: + pipeline.srem(model, backend_url) + pipeline.execute() + + +def is_valid_model(model_name: str): + return redis_running_models.exists(model_name) diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py new file mode 100644 index 0000000..14a6cb0 --- /dev/null +++ b/llm_server/cluster/cluster_config.py @@ -0,0 +1,3 @@ +from llm_server.cluster.redis_config_cache import RedisClusterStore + +cluster_config = RedisClusterStore('cluster_config') diff --git a/llm_server/cluster/datastore.py b/llm_server/cluster/datastore.py deleted file mode 100644 index e69de29..0000000 diff --git a/llm_server/cluster/funcs/__init__.py b/llm_server/cluster/funcs/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/llm_server/cluster/funcs/backend.py b/llm_server/cluster/funcs/backend.py deleted file mode 100644 index 5b7b535..0000000 --- a/llm_server/cluster/funcs/backend.py +++ /dev/null @@ -1,26 +0,0 @@ -from llm_server.cluster.redis_config_cache import RedisClusterStore -from llm_server.llm.info import get_running_model - - -def test_backend(backend_url: str): - running_model, err = get_running_model(backend_url) - if not running_model: - return False - return True - - -def get_best_backends(): - cluster_config = RedisClusterStore('cluster_config') - backends = cluster_config.all() - result = {} - for k, v in backends.items(): - b = cluster_config.get_backend(k) - status = b['online'] - priority = b['priority'] - result[k] = {'status': status, 'priority': priority} - online_backends = sorted( - ((url, info) for url, info in backends.items() if info['online']), - key=lambda kv: kv[1]['priority'], - reverse=True - ) - return [url for url, info in online_backends] diff --git a/llm_server/cluster/redis_config_cache.py b/llm_server/cluster/redis_config_cache.py index 00a6a02..ebb6099 100644 --- a/llm_server/cluster/redis_config_cache.py +++ b/llm_server/cluster/redis_config_cache.py @@ -1,3 +1,4 @@ +import hashlib import pickle from llm_server.custom_redis import RedisCustom @@ -13,14 +14,17 @@ class RedisClusterStore: def load(self, config: dict): for k, v in config.items(): - self.set_backend(k, v) + self.add_backend(k, v) - def set_backend(self, name: str, values: dict): + def add_backend(self, name: str, values: dict): self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()}) self.set_backend_value(name, 'online', False) + h = hashlib.sha256(name.encode('utf-8')).hexdigest() + self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}') - def set_backend_value(self, key: str, name: str, value): - self.config_redis.hset(key, name, pickle.dumps(value)) + def set_backend_value(self, backend: str, key: str, value): + # By storing the value as a pickle we don't have to cast anything when getting the value from Redis. 
+ self.config_redis.hset(backend, key, pickle.dumps(value)) def get_backend(self, name: str): r = self.config_redis.hgetall(name) diff --git a/llm_server/cluster/redis_cycle.py b/llm_server/cluster/redis_cycle.py new file mode 100644 index 0000000..87893ba --- /dev/null +++ b/llm_server/cluster/redis_cycle.py @@ -0,0 +1,21 @@ +import redis + +r = redis.Redis(host='localhost', port=6379, db=9) + + +def redis_cycle(list_name): + while True: + pipe = r.pipeline() + pipe.lpop(list_name) + popped_element = pipe.execute()[0] + if popped_element is None: + return None + r.rpush(list_name, popped_element) + new_list = r.lrange(list_name, 0, -1) + return [x.decode('utf-8') for x in new_list] + + +def load_backend_cycle(list_name: str, elements: list): + r.delete(list_name) + for element in elements: + r.rpush(list_name, element) diff --git a/llm_server/cluster/stores.py b/llm_server/cluster/stores.py new file mode 100644 index 0000000..c0cbdcc --- /dev/null +++ b/llm_server/cluster/stores.py @@ -0,0 +1,3 @@ +from llm_server.custom_redis import RedisCustom + +redis_running_models = RedisCustom('running_models') diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py index 4aaaf6a..bee280a 100644 --- a/llm_server/cluster/worker.py +++ b/llm_server/cluster/worker.py @@ -1,10 +1,10 @@ -import time +from datetime import datetime from threading import Thread -from llm_server.cluster.funcs.backend import test_backend -from llm_server.cluster.redis_config_cache import RedisClusterStore - -cluster_config = RedisClusterStore('cluster_config') +from llm_server.cluster.backend import purge_backend_from_running_models, test_backend +from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.stores import redis_running_models +from llm_server.llm.info import get_running_model def cluster_worker(): @@ -16,10 +16,16 @@ def cluster_worker(): threads.append(thread) for thread in threads: thread.join() - time.sleep(10) def check_backend(n, v): # Check if backends are online - online = test_backend(v['backend_url']) + # TODO: also have test_backend() get the uptime + online = test_backend(v['backend_url'], v['mode']) + if online: + running_model, err = get_running_model(v['backend_url'], v['mode']) + if not err: + cluster_config.set_backend_value(n, 'running_model', running_model) + purge_backend_from_running_models(n) + redis_running_models.sadd(running_model, n) cluster_config.set_backend_value(n, 'online', online) diff --git a/llm_server/config/config.py b/llm_server/config/config.py index 59568d7..b98ea49 100644 --- a/llm_server/config/config.py +++ b/llm_server/config/config.py @@ -32,7 +32,8 @@ config_default_vars = { 'openai_org_name': 'OpenAI', 'openai_silent_trim': False, 'openai_moderation_enabled': True, - 'netdata_root': None + 'netdata_root': None, + 'show_backends': True, } config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 82afe81..09fb127 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -52,6 +52,7 @@ def load_config(config_path): opts.openai_org_name = config['openai_org_name'] opts.openai_silent_trim = config['openai_silent_trim'] opts.openai_moderation_enabled = config['openai_moderation_enabled'] + opts.show_backends = config['show_backends'] if opts.openai_expose_our_model and not opts.openai_api_key: print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.') diff --git 
a/llm_server/custom_redis.py b/llm_server/custom_redis.py index 1b0b5f3..b0db49e 100644 --- a/llm_server/custom_redis.py +++ b/llm_server/custom_redis.py @@ -1,13 +1,13 @@ import pickle import sys import traceback -from typing import Callable, List, Mapping, Union, Optional +from typing import Callable, List, Mapping, Optional, Union import redis as redis_pkg import simplejson as json from flask_caching import Cache from redis import Redis -from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT +from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) @@ -35,12 +35,12 @@ class RedisCustom: def set(self, key, value, ex: Union[ExpiryT, None] = None): return self.redis.set(self._key(key), value, ex=ex) - def get(self, key, dtype=None, default=None): - """ - :param key: - :param dtype: convert to this type - :return: - """ + def get(self, key, default=None, dtype=None): + # TODO: use pickle + import inspect + if inspect.isclass(default): + raise Exception + d = self.redis.get(self._key(key)) if dtype and d: try: @@ -153,11 +153,23 @@ class RedisCustom: keys = [] for key in raw_keys: p = key.decode('utf-8').split(':') - if len(p) > 2: + if len(p) >= 2: + # Delete prefix del p[0] - keys.append(':'.join(p)) + k = ':'.join(p) + if k != '____': + keys.append(k) return keys + def pipeline(self, transaction=True, shard_hint=None): + return self.redis.pipeline(transaction, shard_hint) + + def exists(self, *names: KeyT): + n = [] + for name in names: + n.append(self._key(name)) + return self.redis.exists(*n) + def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): return self.set(key, json.dumps(dict_value), ex=ex) @@ -174,7 +186,7 @@ class RedisCustom: def getp(self, name: str): r = self.redis.get(name) if r: - return pickle.load(r) + return pickle.loads(r) return r def flush(self): diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 3779c83..bf5f537 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -1,60 +1,69 @@ import json import time import traceback +from threading import Thread import llm_server from llm_server import opts +from llm_server.custom_redis import redis from llm_server.database.conn import database from llm_server.llm.vllm import tokenize -from llm_server.custom_redis import redis -def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False): - if isinstance(response, dict) and response.get('results'): - response = response['results'][0]['text'] - try: - j = json.loads(response) - if j.get('results'): - response = j['results'][0]['text'] - except: - pass +def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False): + def background_task(): + nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error + # Try not to shove JSON into the database. 
+ if isinstance(response, dict) and response.get('results'): + response = response['results'][0]['text'] + try: + j = json.loads(response) + if j.get('results'): + response = j['results'][0]['text'] + except: + pass - prompt_tokens = llm_server.llm.get_token_count(prompt) - if not is_error: - if not response_tokens: - response_tokens = llm_server.llm.get_token_count(response) - else: - response_tokens = None + prompt_tokens = llm_server.llm.get_token_count(prompt) + if not is_error: + if not response_tokens: + response_tokens = llm_server.llm.get_token_count(response) + else: + response_tokens = None - # Sometimes we may want to insert null into the DB, but - # usually we want to insert a float. - if gen_time: - gen_time = round(gen_time, 3) - if is_error: - gen_time = None + # Sometimes we may want to insert null into the DB, but + # usually we want to insert a float. + if gen_time: + gen_time = round(gen_time, 3) + if is_error: + gen_time = None - if not opts.log_prompts: - prompt = None + if not opts.log_prompts: + prompt = None - if not opts.log_prompts and not is_error: - # TODO: test and verify this works as expected - response = None + if not opts.log_prompts and not is_error: + # TODO: test and verify this works as expected + response = None - if token: - increment_token_uses(token) + if token: + increment_token_uses(token) - running_model = redis.get('running_model', str, 'ERROR') - timestamp = int(time.time()) - cursor = database.cursor() - try: - cursor.execute(""" - INSERT INTO prompts - (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) - VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) - """, - (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) - finally: - cursor.close() + running_model = redis.get('running_model', str, 'ERROR') + timestamp = int(time.time()) + cursor = database.cursor() + try: + cursor.execute(""" + INSERT INTO prompts + (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) + """, + (ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) + finally: + cursor.close() + + # TODO: use async/await instead of threads + thread = Thread(target=background_task) + thread.start() + thread.join() def is_valid_api_key(api_key): diff --git a/llm_server/helpers.py b/llm_server/helpers.py index d6eb7d9..9fc7274 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -60,7 +60,7 @@ def round_up_base(n, base): def auto_set_base_client_api(request): - http_host = redis.get('http_host', str) + http_host = redis.get('http_host', dtype=str) host = request.headers.get("Host") if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host): # If the current http_host is not an IP, don't do anything. 
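[Editor's note] The new log_prompt() above moves its database insert into a Thread but then join()s it, so the caller still blocks until the row is written; the in-code TODO points at async/await as the eventual replacement. Purely as an illustration of what dropping the join() would look like — not part of this patch, and assuming the caller never needs the insert's result — a minimal fire-and-forget sketch (the helper names below are illustrative only):

    from threading import Thread

    def _write_log_row(**fields):
        # Stand-in for the INSERT body of log_prompt() above; this name does
        # not exist anywhere in the patch.
        pass

    def log_prompt_nonblocking(**fields):
        t = Thread(target=_write_log_row, kwargs=fields)
        t.daemon = True  # a pending log write should not keep the process alive
        t.start()        # no join(), so the request handler returns immediately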
diff --git a/llm_server/integer.py b/llm_server/integer.py deleted file mode 100644 index 1410dd1..0000000 --- a/llm_server/integer.py +++ /dev/null @@ -1,12 +0,0 @@ -import threading - - -class ThreadSafeInteger: - def __init__(self, value=0): - self.value = value - self._value_lock = threading.Lock() - - def increment(self): - with self._value_lock: - self.value += 1 - return self.value diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index a08b25e..6e39b42 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -3,7 +3,7 @@ from llm_server.custom_redis import redis def get_token_count(prompt: str): - backend_mode = redis.get('backend_mode', str) + backend_mode = redis.get('backend_mode', dtype=str) if backend_mode == 'vllm': return vllm.tokenize(prompt) elif backend_mode == 'ooba': diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index 5dd2093..42c3bb7 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -1,14 +1,14 @@ from llm_server import opts -def generator(request_json_body): +def generator(request_json_body, cluster_backend): if opts.mode == 'oobabooga': # from .oobabooga.generate import generate # return generate(request_json_body) raise NotImplementedError elif opts.mode == 'vllm': from .vllm.generate import generate - r = generate(request_json_body) + r = generate(request_json_body, cluster_backend) return r else: raise Exception diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index bedf3eb..117da3f 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -3,19 +3,15 @@ import requests from llm_server import opts -def get_running_model(backend_url: str): - # TODO: remove this once we go to Redis - if not backend_url: - backend_url = opts.backend_url - - if opts.mode == 'oobabooga': +def get_running_model(backend_url: str, mode: str): + if mode == 'ooba': try: backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['result'], None except Exception as e: return False, e - elif opts.mode == 'vllm': + elif mode == 'vllm': try: backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index 153f66d..e8268b1 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -40,6 +40,6 @@ class LLMBackend: def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]: prompt_len = get_token_count(prompt) if prompt_len > opts.context_size - 10: - model_name = redis.get('running_model', str, 'NO MODEL ERROR') + model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str) return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). 
Please lower your context size' return True, None diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index 8f1898e..62e0ed8 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -34,7 +34,7 @@ def build_openai_response(prompt, response, model=None): # TODO: async/await prompt_tokens = llm_server.llm.get_token_count(prompt) response_tokens = llm_server.llm.get_token_count(response) - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) response = make_response(jsonify({ "id": f"chatcmpl-{generate_oai_string(30)}", @@ -57,7 +57,7 @@ def build_openai_response(prompt, response, model=None): } }), 200) - stats = redis.get('proxy_stats', dict) + stats = redis.get('proxy_stats', dtype=dict) if stats: response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] return response diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 308b1de..caac445 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -49,7 +49,7 @@ def transform_to_text(json_request, api_response): prompt_tokens = len(llm_server.llm.get_token_count(prompt)) completion_tokens = len(llm_server.llm.get_token_count(text)) - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) # https://platform.openai.com/docs/api-reference/making-requests?lang=python return { @@ -82,9 +82,9 @@ def transform_prompt_to_text(prompt: list): return text.strip('\n') -def handle_blocking_request(json_data: dict): +def handle_blocking_request(json_data: dict, cluster_backend): try: - r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) except requests.exceptions.ReadTimeout: print(f'Failed to reach VLLM inference endpoint - request to backend timed out') return False, None, 'Request to backend timed out' @@ -97,11 +97,11 @@ def handle_blocking_request(json_data: dict): return True, r, None -def generate(json_data: dict): +def generate(json_data: dict, cluster_backend): if json_data.get('stream'): try: - return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) except Exception as e: print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') else: - return handle_blocking_request(json_data) + return handle_blocking_request(json_data, cluster_backend) diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index e5b0fad..3db99d9 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -19,16 +19,8 @@ class VLLMBackend(LLMBackend): # Failsafe backend_response = '' - r_url = request.url - - def background_task(): - log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url, - 
response_tokens=response_json_body.get('details', {}).get('generated_tokens')) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task) - thread.start() - thread.join() + log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, + response_tokens=response_json_body.get('details', {}).get('generated_tokens')) return jsonify({'results': [{'text': backend_response}]}), 200 diff --git a/llm_server/opts.py b/llm_server/opts.py index 5eec1fa..0d13979 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -2,7 +2,6 @@ # TODO: rewrite the config system so I don't have to add every single config default here -running_model = 'ERROR' concurrent_gens = 3 mode = 'oobabooga' backend_url = None @@ -38,3 +37,4 @@ openai_org_name = 'OpenAI' openai_silent_trim = False openai_moderation_enabled = True cluster = {} +show_backends = True diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py index 21da08e..900210c 100644 --- a/llm_server/pre_fork.py +++ b/llm_server/pre_fork.py @@ -7,7 +7,7 @@ from llm_server.routes.v1.generate_stats import generate_stats def server_startup(s): - if not redis.get('daemon_started', bool): + if not redis.get('daemon_started', dtype=bool): print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') sys.exit(1) diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py index d97e9c5..a914362 100644 --- a/llm_server/routes/helpers/client.py +++ b/llm_server/routes/helpers/client.py @@ -1,10 +1,14 @@ +from llm_server.cluster.cluster_config import cluster_config from llm_server.custom_redis import redis -def format_sillytavern_err(msg: str, level: str = 'info'): - http_host = redis.get('http_host', str) +def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'): + cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash'] + http_host = redis.get('http_host', dtype=str) return f"""``` === MESSAGE FROM LLM MIDDLEWARE AT {http_host} === -> {level.upper()} <- {msg} + +BACKEND HASH: {cluster_backend_hash} ```""" diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index d6b02e2..8e0036c 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler): msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' 
backend_response = self.handle_error(msg) if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True) return backend_response[0], 200 # We only return the response from handle_error(), not the error code def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: @@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler): # TODO: how to format this response_msg = error_msg else: - response_msg = format_sillytavern_err(error_msg, error_type) + response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend) return jsonify({ 'results': [{'text': response_msg}] diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index a289c78..b3159a5 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -5,11 +5,12 @@ import traceback from flask import Response, jsonify, request -from . import openai_bp from llm_server.custom_redis import redis +from . import openai_bp from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler from ... import opts +from ...cluster.backend import get_a_cluster_backend from ...database.database import log_prompt from ...llm.generator import generator from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt @@ -48,10 +49,11 @@ def openai_chat_completions(): 'stream': True, } try: - response = generator(msg_to_backend) + cluster_backend = get_a_cluster_backend() + response = generator(msg_to_backend, cluster_backend) r_headers = dict(request.headers) r_url = request.url - model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model') + model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') oai_string = generate_oai_string(30) def generate(): @@ -94,7 +96,7 @@ def openai_chat_completions(): def background_task(): generated_tokens = tokenize(generated_text) - log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens) + log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens) # TODO: use async/await instead of threads thread = threading.Thread(target=background_task) diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 05ac7a5..8950927 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -29,7 +29,7 @@ def openai_completions(): # TODO: async/await prompt_tokens = get_token_count(request_json_body['prompt']) response_tokens = get_token_count(output) - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) response = make_response(jsonify({ "id": f"cmpl-{generate_oai_string(30)}", @@ -51,7 +51,7 @@ def 
openai_completions(): } }), 200) - stats = redis.get('proxy_stats', dict) + stats = redis.get('proxy_stats', dtype=dict) if stats: response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] return response diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py index 4f732e6..657f084 100644 --- a/llm_server/routes/openai/models.py +++ b/llm_server/routes/openai/models.py @@ -7,6 +7,7 @@ from . import openai_bp from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis from ..stats import server_start_time from ... import opts +from ...cluster.backend import get_a_cluster_backend from ...helpers import jsonify_pretty from ...llm.info import get_running_model @@ -22,7 +23,7 @@ def openai_list_models(): 'type': error.__class__.__name__ }), 500 # return 500 so Cloudflare doesn't intercept us else: - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) oai = fetch_openai_models() r = [] if opts.openai_expose_our_model: diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 8d85319..09ed06c 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -93,6 +93,6 @@ def incr_active_workers(): def decr_active_workers(): redis.decr('active_gen_workers') - new_count = redis.get('active_gen_workers', int, 0) + new_count = redis.get('active_gen_workers', 0, dtype=int) if new_count < 0: redis.set('active_gen_workers', 0) diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index bb64859..ecae085 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -5,13 +5,15 @@ import flask from flask import Response, request from llm_server import opts +from llm_server.cluster.backend import get_a_cluster_backend +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis from llm_server.database.conn import database from llm_server.database.database import log_prompt from llm_server.helpers import auto_set_base_client_api from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.auth import parse_token -from llm_server.custom_redis import redis from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.queue import priority_queue @@ -35,7 +37,9 @@ class RequestHandler: self.client_ip = self.get_client_ip() self.token = self.get_auth_token() self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit() - self.backend = get_backend() + self.cluster_backend = get_a_cluster_backend() + self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend) + self.backend = get_backend_handler(self.cluster_backend) self.parameters = None self.used = False redis.zadd('recent_prompters', {self.client_ip: time.time()}) @@ -119,7 +123,7 @@ class RequestHandler: backend_response = self.handle_error(combined_error_message, 'Validation Error') if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True) 
return False, backend_response return True, (None, 0) @@ -131,7 +135,7 @@ class RequestHandler: request_valid, invalid_response = self.validate_request(prompt, do_log=True) if not request_valid: return (False, None, None, 0), invalid_response - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority) + event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority) else: event = None @@ -160,7 +164,7 @@ class RequestHandler: else: error_msg = error_msg.strip('.') + '.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -180,7 +184,7 @@ class RequestHandler: if return_json_err: error_msg = 'The backend did not return valid JSON.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -214,10 +218,10 @@ class RequestHandler: raise NotImplementedError -def get_backend(): - if opts.mode == 'oobabooga': +def get_backend_handler(mode): + if mode == 'oobabooga': return OobaboogaBackend() - elif opts.mode == 'vllm': + elif mode == 'vllm': return VLLMBackend() else: raise Exception diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index a0846d9..b4dea54 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -2,32 +2,9 @@ from datetime import datetime from llm_server.custom_redis import redis -# proompters_5_min = 0 -# concurrent_semaphore = Semaphore(concurrent_gens) - server_start_time = datetime.now() -# TODO: do I need this? 
-# def elapsed_times_cleanup(): -# global wait_in_queue_elapsed -# while True: -# current_time = time.time() -# with wait_in_queue_elapsed_lock: -# global wait_in_queue_elapsed -# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60] -# time.sleep(1) - - -def calculate_avg_gen_time(): - # Get the average generation time from Redis - average_generation_time = redis.get('average_generation_time') - if average_generation_time is None: - return 0 - else: - return float(average_generation_time) - - def get_total_proompts(): count = redis.get('proompts') if count is None: diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index b2dd527..66dd316 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -2,11 +2,12 @@ import time from datetime import datetime from llm_server import opts +from llm_server.cluster.backend import get_a_cluster_backend, test_backend +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis from llm_server.database.database import get_distinct_ips_24h, sum_column from llm_server.helpers import deep_sort, round_up_base from llm_server.llm.info import get_running_model -from llm_server.netdata import get_power_states -from llm_server.custom_redis import redis from llm_server.routes.queue import priority_queue from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time @@ -33,52 +34,43 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act return gen_time_calc -# TODO: have routes/__init__.py point to the latest API version generate_stats() - def generate_stats(regen: bool = False): if not regen: - c = redis.get('proxy_stats', dict) + c = redis.get('proxy_stats', dtype=dict) if c: return c - model_name, error = get_running_model() # will return False when the fetch fails - if isinstance(model_name, bool): - online = False - else: - online = True - redis.set('running_model', model_name) + default_backend_url = get_a_cluster_backend() + default_backend_info = cluster_config.get_backend(default_backend_url) + if not default_backend_info.get('mode'): + # TODO: remove + print('DAEMON NOT FINISHED STARTING') + return + base_client_api = redis.get('base_client_api', dtype=str) + proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf')) + average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0) - # t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change - # if len(t) == 0: - # estimated_wait = 0 - # else: - # waits = [elapsed for end, elapsed in t] - # estimated_wait = int(sum(waits) / len(waits)) + online = test_backend(default_backend_url, default_backend_info['mode']) + if online: + running_model, err = get_running_model(default_backend_url, default_backend_info['mode']) + cluster_config.set_backend_value(default_backend_url, 'running_model', running_model) + else: + running_model = None active_gen_workers = get_active_gen_workers() proompters_in_queue = len(priority_queue) - # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM. + # This is so wildly inaccurate it's disabled. 
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0) - average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0) - estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers) - - if opts.netdata_root: - netdata_stats = {} - power_states = get_power_states() - for gpu, power_state in power_states.items(): - netdata_stats[gpu] = { - 'power_state': power_state, - # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu'))) - } - else: - netdata_stats = {} - - base_client_api = redis.get('base_client_api', str) - proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf')) + # TODO: make this for the currently selected backend + estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers) output = { + 'default': { + 'model': running_model, + 'backend': default_backend_info['hash'], + }, 'stats': { 'proompters': { '5_min': proompters_5_min, @@ -86,9 +78,10 @@ def generate_stats(regen: bool = False): }, 'proompts_total': get_total_proompts() if opts.show_num_prompts else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, - 'average_generation_elapsed_sec': int(average_generation_time), + 'average_generation_elapsed_sec': int(average_generation_elapsed_sec), # 'estimated_avg_tps': estimated_avg_tps, 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, + 'num_backends': len(cluster_config.all()) if opts.show_backends else None, }, 'online': online, 'endpoints': { @@ -103,10 +96,7 @@ def generate_stats(regen: bool = False): 'timestamp': int(time.time()), 'config': { 'gatekeeper': 'none' if opts.auth_required is False else 'token', - 'context_size': opts.context_size, 'concurrent': opts.concurrent_gens, - 'model': opts.manual_model_name if opts.manual_model_name else model_name, - 'mode': opts.mode, 'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip, }, 'keys': { @@ -114,8 +104,41 @@ def generate_stats(regen: bool = False): 'anthropicKeys': '∞', }, 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None, - 'nvidia': netdata_stats } + + if opts.show_backends: + for backend_url, v in cluster_config.all().items(): + backend_info = cluster_config.get_backend(backend_url) + if not backend_info['online']: + continue + + # TODO: have this fetch the data from VLLM which will display GPU utalization + # if opts.netdata_root: + # netdata_stats = {} + # power_states = get_power_states() + # for gpu, power_state in power_states.items(): + # netdata_stats[gpu] = { + # 'power_state': power_state, + # # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu'))) + # } + # else: + # netdata_stats = {} + netdata_stats = {} + + # TODO: use value returned by VLLM backend here + # backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None + backend_uptime = -1 + + output['backend_info'][backend_info['hash']] = { + 'uptime': backend_uptime, + # 'context_size': opts.context_size, + 'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'), + 'mode': backend_info['mode'], + 'nvidia': netdata_stats + } + else: + output['backend_info'] = {} + result = deep_sort(output) # It may take a bit to get the base client API, so don't cache until then. 
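[Editor's note] The 'backend' value in the default block and the keys of backend_info above are the short backend hashes that RedisClusterStore.add_backend() stores for each backend URL: the SHA-256 of the URL, keeping the first and last eight hex digits. A standalone sketch of that derivation; backend_hash() is an illustrative name, not a function defined in this patch:

    import hashlib

    def backend_hash(backend_url: str) -> str:
        # Mirrors RedisClusterStore.add_backend(): hash the backend URL and keep
        # only the first and last 8 hex characters, joined by a dash, for display.
        h = hashlib.sha256(backend_url.encode('utf-8')).hexdigest()
        return f'{h[:8]}-{h[-8:]}'

    # Example: backend_hash('http://127.0.0.1:7000') returns a 17-character
    # identifier such as 'a1b2c3d4-e5f6a7b8' (the digits depend on the URL).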
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 45fbf12..e3aeeb0 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -1,5 +1,4 @@ import json -import threading import time import traceback from typing import Union @@ -10,10 +9,11 @@ from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ... import opts +from ...cluster.backend import get_a_cluster_backend from ...database.database import log_prompt from ...llm.generator import generator from ...llm.vllm import tokenize -from ...stream import sock +from ...sock import sock # TODO: have workers process streaming requests @@ -35,19 +35,13 @@ def stream(ws): log_in_bg(quitting_err_msg, is_error=True) def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None): - - def background_task_exception(): - generated_tokens = tokenize(generated_text_bg) - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task_exception) - thread.start() - thread.join() + generated_tokens = tokenize(generated_text_bg) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error) if not opts.enable_streaming: return 'Streaming is disabled', 401 + cluster_backend = None r_headers = dict(request.headers) r_url = request.url message_num = 0 @@ -90,14 +84,15 @@ def stream(ws): } # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority) + event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority) if not event: r, _ = handler.handle_ratelimited() err_msg = r.json['results'][0]['text'] send_err_and_quit(err_msg) return try: - response = generator(llm_request) + cluster_backend = get_a_cluster_backend() + response = generator(llm_request, cluster_backend) if not response: error_msg = 'Failed to reach backend while streaming.' print('Streaming failed:', error_msg) @@ -142,7 +137,7 @@ def stream(ws): ws.close() end_time = time.time() elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text)) return message_num += 1 @@ -181,5 +176,5 @@ def stream(ws): # The client closed the stream. 
end_time = time.time() elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) + log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text)) ws.close() # this is important if we encountered and error and exited early. diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py index 7cdbf0f..90778e5 100644 --- a/llm_server/routes/v1/info.py +++ b/llm_server/routes/v1/info.py @@ -2,22 +2,21 @@ import time from flask import jsonify, request +from llm_server.custom_redis import flask_cache from . import bp from ..auth import requires_auth -from llm_server.custom_redis import flask_cache from ... import opts -from ...llm.info import get_running_model - - -# @bp.route('/info', methods=['GET']) -# # @cache.cached(timeout=3600, query_string=True) -# def get_info(): -# # requests.get() -# return 'yes' +from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model +from ...cluster.cluster_config import cluster_config @bp.route('/model', methods=['GET']) -def get_model(): +@bp.route('//model', methods=['GET']) +def get_model(model_name=None): + if not model_name: + b = get_a_cluster_backend() + model_name = cluster_config.get_backend(b)['running_model'] + # We will manage caching ourself since we don't want to cache # when the backend is down. Also, Cloudflare won't cache 500 errors. cache_key = 'model_cache::' + request.url @@ -26,16 +25,17 @@ def get_model(): if cached_response: return cached_response - model_name, error = get_running_model() - if not model_name: + if not is_valid_model(model_name): response = jsonify({ - 'code': 502, - 'msg': 'failed to reach backend', - 'type': error.__class__.__name__ - }), 500 # return 500 so Cloudflare doesn't intercept us + 'code': 400, + 'msg': 'Model does not exist.', + }), 400 else: + num_backends = len(get_backends_from_model(model_name)) + response = jsonify({ 'result': opts.manual_model_name if opts.manual_model_name else model_name, + 'model_backend_count': num_backends, 'timestamp': int(time.time()) }), 200 flask_cache.set(cache_key, response, timeout=60) @@ -43,7 +43,11 @@ def get_model(): return response -@bp.route('/backend', methods=['GET']) +@bp.route('/backends', methods=['GET']) @requires_auth def get_backend(): - return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200 + online, offline = get_backends() + result = [] + for i in online + offline: + result.append(cluster_config.get_backend(i)) + return jsonify(result), 200 diff --git a/llm_server/stream.py b/llm_server/sock.py similarity index 100% rename from llm_server/stream.py rename to llm_server/sock.py diff --git a/llm_server/workers/app.py b/llm_server/workers/app.py deleted file mode 100644 index fda6fb3..0000000 --- a/llm_server/workers/app.py +++ /dev/null @@ -1,35 +0,0 @@ -from threading import Thread - -from .blocking import start_workers -from .main import main_background_thread -from .moderator import start_moderation_workers -from .printer import console_printer -from .recent import recent_prompters_thread -from .threads import cache_stats -from .. 
import opts - - -def start_background(): - start_workers(opts.concurrent_gens) - - t = Thread(target=main_background_thread) - t.daemon = True - t.start() - print('Started the main background thread.') - - start_moderation_workers(opts.openai_moderation_workers) - - t = Thread(target=cache_stats) - t.daemon = True - t.start() - print('Started the stats cacher.') - - t = Thread(target=recent_prompters_thread) - t.daemon = True - t.start() - print('Started the recent proompters thread.') - - t = Thread(target=console_printer) - t.daemon = True - t.start() - print('Started the console printer.') diff --git a/llm_server/workers/blocking.py b/llm_server/workers/inferencer.py similarity index 88% rename from llm_server/workers/blocking.py rename to llm_server/workers/inferencer.py index dcf0047..626e34b 100644 --- a/llm_server/workers/blocking.py +++ b/llm_server/workers/inferencer.py @@ -2,15 +2,15 @@ import threading import time from llm_server import opts -from llm_server.llm.generator import generator from llm_server.custom_redis import redis +from llm_server.llm.generator import generator from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue def worker(): while True: need_to_wait() - (request_json_body, client_ip, token, parameters), event_id = priority_queue.get() + (request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get() need_to_wait() increment_ip_count(client_ip, 'processing_ips') @@ -22,7 +22,7 @@ def worker(): continue try: - success, response, error_msg = generator(request_json_body) + success, response, error_msg = generator(request_json_body, cluster_backend) event = DataEvent(event_id) event.set((success, response, error_msg)) finally: @@ -42,7 +42,7 @@ def start_workers(num_workers: int): def need_to_wait(): # We need to check the number of active workers since the streaming endpoint may be doing something. - active_workers = redis.get('active_gen_workers', int, 0) + active_workers = redis.get('active_gen_workers', 0, dtype=int) s = time.time() while active_workers >= opts.concurrent_gens: time.sleep(0.01) diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py deleted file mode 100644 index f592c5e..0000000 --- a/llm_server/workers/main.py +++ /dev/null @@ -1,55 +0,0 @@ -import time - -from llm_server import opts -from llm_server.database.database import weighted_average_column_for_model -from llm_server.llm.info import get_running_model -from llm_server.custom_redis import redis - - -def main_background_thread(): - redis.set('average_generation_elapsed_sec', 0) - redis.set('estimated_avg_tps', 0) - redis.set('average_output_tokens', 0) - redis.set('backend_online', 0) - redis.set_dict('backend_info', {}) - - while True: - # TODO: unify this - if opts.mode == 'oobabooga': - running_model, err = get_running_model() - if err: - print(err) - redis.set('backend_online', 0) - else: - redis.set('running_model', running_model) - redis.set('backend_online', 1) - elif opts.mode == 'vllm': - running_model, err = get_running_model() - if err: - print(err) - redis.set('backend_online', 0) - else: - redis.set('running_model', running_model) - redis.set('backend_online', 1) - else: - raise Exception - - # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 - # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. 
- average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0 - if average_generation_elapsed_sec: # returns None on exception - redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec) - - # overall = average_column_for_model('prompts', 'generation_time', opts.running_model) - # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}') - - average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0 - if average_generation_elapsed_sec: - redis.set('average_output_tokens', average_output_tokens) - - # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model) - # print(f'Weighted: {average_output_tokens}, overall: {overall}') - - estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero - redis.set('estimated_avg_tps', estimated_avg_tps) - time.sleep(60) diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py new file mode 100644 index 0000000..447046f --- /dev/null +++ b/llm_server/workers/mainer.py @@ -0,0 +1,56 @@ +import time + +from llm_server import opts +from llm_server.cluster.backend import get_a_cluster_backend, get_backends +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis +from llm_server.database.database import weighted_average_column_for_model +from llm_server.llm.info import get_running_model + + +def main_background_thread(): + while True: + online, offline = get_backends() + for backend_url in online: + backend_info = cluster_config.get_backend(backend_url) + backend_mode = backend_info['mode'] + running_model, err = get_running_model(backend_url, backend_mode) + if err: + continue + + average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode) + if average_generation_elapsed_sec: # returns None on exception + cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec) + if average_output_tokens: + cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens) + if average_generation_elapsed_sec and average_output_tokens: + cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps) + + default_backend_url = get_a_cluster_backend() + default_backend_info = cluster_config.get_backend(default_backend_url) + default_backend_mode = default_backend_info['mode'] + default_running_model, err = get_running_model(default_backend_url, default_backend_mode) + if err: + continue + + default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_running_model, default_running_model, default_backend_mode) + if default_average_generation_elapsed_sec: + redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec) + if default_average_output_tokens: + redis.set('average_output_tokens', default_average_output_tokens) + if default_average_generation_elapsed_sec and default_average_output_tokens: + redis.set('estimated_avg_tps', default_estimated_avg_tps) + time.sleep(30) + + 
+def calc_stats_for_backend(backend_url, running_model, backend_mode): + # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 + # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. + average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', + running_model, backend_mode, backend_url, exclude_zeros=True, + include_system_tokens=opts.include_system_tokens_in_stats) or 0 + average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', + running_model, backend_mode, backend_url, exclude_zeros=True, + include_system_tokens=opts.include_system_tokens_in_stats) or 0 + estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero + return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps diff --git a/llm_server/workers/recent.py b/llm_server/workers/recenter.py similarity index 100% rename from llm_server/workers/recent.py rename to llm_server/workers/recenter.py diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py new file mode 100644 index 0000000..83bac2d --- /dev/null +++ b/llm_server/workers/threader.py @@ -0,0 +1,50 @@ +import time +from threading import Thread + +from llm_server import opts +from llm_server.cluster.stores import redis_running_models +from llm_server.cluster.worker import cluster_worker +from llm_server.routes.v1.generate_stats import generate_stats +from llm_server.workers.inferencer import start_workers +from llm_server.workers.mainer import main_background_thread +from llm_server.workers.moderator import start_moderation_workers +from llm_server.workers.printer import console_printer +from llm_server.workers.recenter import recent_prompters_thread + + +def cache_stats(): + while True: + generate_stats(regen=True) + time.sleep(1) + + +def start_background(): + start_workers(opts.concurrent_gens) + + t = Thread(target=main_background_thread) + t.daemon = True + t.start() + print('Started the main background thread.') + + start_moderation_workers(opts.openai_moderation_workers) + + t = Thread(target=cache_stats) + t.daemon = True + t.start() + print('Started the stats cacher.') + + t = Thread(target=recent_prompters_thread) + t.daemon = True + t.start() + print('Started the recent proompters thread.') + + t = Thread(target=console_printer) + t.daemon = True + t.start() + print('Started the console printer.') + + redis_running_models.flush() + t = Thread(target=cluster_worker) + t.daemon = True + t.start() + print('Started the cluster worker.') diff --git a/llm_server/workers/threads.py b/llm_server/workers/threads.py deleted file mode 100644 index d1c5183..0000000 --- a/llm_server/workers/threads.py +++ /dev/null @@ -1,9 +0,0 @@ -import time - -from llm_server.routes.v1.generate_stats import generate_stats - - -def cache_stats(): - while True: - generate_stats(regen=True) - time.sleep(5) diff --git a/gunicorn.py b/other/gunicorn.py similarity index 60% rename from gunicorn.py rename to other/gunicorn.py index 30f9274..099e9ce 100644 --- a/gunicorn.py +++ b/other/gunicorn.py @@ -1,3 +1,8 @@ +""" +This file is used to run certain tasks when the HTTP server starts. 
+It's located here so it doesn't get imported with daemon.py +""" + try: import gevent.monkey diff --git a/server.py b/server.py index 3c334bc..0214b49 100644 --- a/server.py +++ b/server.py @@ -1,4 +1,4 @@ -from llm_server.config.config import mode_ui_names +from llm_server.cluster.cluster_config import cluster_config try: import gevent.monkey @@ -7,8 +7,6 @@ try: except ImportError: pass -from llm_server.pre_fork import server_startup -from llm_server.config.load import load_config, parse_backends import os import sys from pathlib import Path @@ -16,14 +14,17 @@ from pathlib import Path import simplejson as json from flask import Flask, jsonify, render_template, request -import llm_server +from llm_server.cluster.backend import get_a_cluster_backend, get_backends +from llm_server.cluster.redis_cycle import load_backend_cycle +from llm_server.config.config import mode_ui_names +from llm_server.config.load import load_config, parse_backends from llm_server.database.conn import database from llm_server.database.create import create_db -from llm_server.llm import get_token_count +from llm_server.pre_fork import server_startup from llm_server.routes.openai import openai_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp -from llm_server.stream import init_socketio +from llm_server.sock import init_socketio # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail @@ -37,6 +38,8 @@ from llm_server.stream import init_socketio # TODO: use coloredlogs # TODO: need to update opts. for workers # TODO: add a healthcheck to VLLM +# TODO: allow choosing the model by the URL path +# TODO: have VLLM report context size, uptime # Lower priority # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens @@ -64,7 +67,7 @@ import config from llm_server import opts from llm_server.helpers import auto_set_base_client_api from llm_server.llm.vllm.info import vllm_info -from llm_server.custom_redis import RedisCustom, flask_cache +from llm_server.custom_redis import flask_cache from llm_server.llm import redis from llm_server.routes.stats import get_active_gen_workers from llm_server.routes.v1.generate_stats import generate_stats @@ -83,20 +86,18 @@ if config_path_environ: else: config_path = Path(script_path, 'config', 'config.yml') -success, config, msg = load_config(config_path, script_path) +success, config, msg = load_config(config_path) if not success: print('Failed to load config:', msg) sys.exit(1) database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) create_db() -llm_server.llm.redis = RedisCustom('local_llm') -create_db() -x = parse_backends(config) -print(x) - -# print(app.url_map) +cluster_config.clear() +cluster_config.load(parse_backends(config)) +on, off = get_backends() +load_backend_cycle('backend_cycler', on + off) @app.route('/') @@ -104,12 +105,18 @@ print(x) @app.route('/api/openai') @flask_cache.cached(timeout=10) def home(): - stats = generate_stats() + # Use the default backend + backend_url = get_a_cluster_backend() + if backend_url: + backend_info = cluster_config.get_backend(backend_url) + stats = generate_stats(backend_url) + else: + backend_info = stats = None if not stats['online']: running_model = estimated_wait_sec = 'offline' else: - 
running_model = redis.get('running_model', str, 'ERROR')
+        running_model = backend_info['running_model']
 
     active_gen_workers = get_active_gen_workers()
 
     if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
@@ -130,10 +137,16 @@ def home():
         info_html = ''
 
     mode_info = ''
-    if opts.mode == 'vllm':
+    using_vllm = False
+    for k, v in cluster_config.all().items():
+        if v['mode'] == 'vllm':
+            using_vllm = True
+            break
+
+    if using_vllm:
         mode_info = vllm_info
 
-    base_client_api = redis.get('base_client_api', str)
+    base_client_api = redis.get('base_client_api', dtype=str)
 
     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,
diff --git a/test-cluster.py b/test-cluster.py
index 531892b..ec1773a 100644
--- a/test-cluster.py
+++ b/test-cluster.py
@@ -7,23 +7,33 @@ except ImportError:
 
 import time
 from threading import Thread
+from llm_server.cluster.redis_cycle import load_backend_cycle
 
-from llm_server.cluster.funcs.backend import get_best_backends
-from llm_server.cluster.redis_config_cache import RedisClusterStore
+from llm_server.cluster.backend import get_backends, get_a_cluster_backend
 from llm_server.cluster.worker import cluster_worker
 from llm_server.config.load import parse_backends, load_config
+from llm_server.cluster.redis_config_cache import RedisClusterStore
 
-success, config, msg = load_config('./config/config.yml').resolve().absolute()
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('config')
+args = parser.parse_args()
+
+success, config, msg = load_config(args.config)
 
 cluster_config = RedisClusterStore('cluster_config')
 cluster_config.clear()
 cluster_config.load(parse_backends(config))
+on, off = get_backends()
+load_backend_cycle('backend_cycler', on + off)
 
 t = Thread(target=cluster_worker)
 t.daemon = True
 t.start()
 
 while True:
-    x = get_best_backends()
-    print(x)
+    # online, offline = get_backends()
+    # print(online, offline)
+    # print(get_a_cluster_backend())
     time.sleep(3)
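
A note on the stats math above: calc_stats_for_backend calls weighted_average_column_for_model with exclude_zeros=True because, as its comment explains, error rows used to be written as 0 and are now written as NULL, so both have to be skipped. That helper's implementation is not part of this patch; the sketch below only illustrates the filtering-plus-recency-weighting idea, and the prompts table layout (model, backend_mode and backend_url columns) is assumed from the call sites rather than taken from the real schema.

# Hypothetical sketch, not the real weighted_average_column_for_model.
# It skips NULL (new error marker) and 0 (legacy error marker) rows and
# weights newer rows more heavily so recent performance dominates.
def sketch_weighted_average(cursor, column: str, model: str, backend_mode: str, backend_url: str):
    cursor.execute(
        f"SELECT {column} FROM prompts "
        "WHERE model = %s AND backend_mode = %s AND backend_url = %s "
        f"AND {column} IS NOT NULL AND {column} > 0 "
        "ORDER BY id DESC LIMIT 100",
        (model, backend_mode, backend_url),
    )
    values = [row[0] for row in cursor.fetchall()]
    if not values:
        return None  # mirrors the "returns None" case checked in mainer.py
    weights = list(range(len(values), 0, -1))  # newest row gets the largest weight
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)

For example, with values [2.0, 4.0] (newest first) the weights are [2, 1], giving (2.0 * 2 + 4.0 * 1) / 3, roughly 2.67, so the newer sample pulls the average toward itself.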
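server.py and test-cluster.py both prime a shared rotation with load_backend_cycle('backend_cycler', on + off), and the cluster code presumably draws from it via redis_cycle, but redis_cycle.py itself is not reproduced in this part of the patch. The sketch below is only a guess at the round-robin behaviour those names suggest; the 'backend_cycler' key name comes from the calls above, while the connection details and function bodies are assumptions.

import redis

r = redis.Redis(host='localhost', port=6379, db=0)  # connection details assumed


def load_backend_cycle(name: str, backend_urls: list):
    # Replace the rotation list with the current set of backend URLs.
    r.delete(name)
    if backend_urls:
        r.rpush(name, *backend_urls)


def redis_cycle(name: str):
    # Pop the next URL off the front and push it back onto the tail,
    # so every process sharing the Redis list sees one global round-robin.
    url = r.lpop(name)
    if url is None:
        return None
    r.rpush(name, url)
    return url.decode()

Keeping the cycle in Redis rather than in per-process memory would let the gunicorn workers and the daemon rotate through the same backend order, which appears to be why it is wired up in both server.py and test-cluster.py.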