diff --git a/README.md b/README.md index 429e246..ccfaaf4 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,9 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database. ### Use +If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database. +Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app` ### To Do diff --git a/daemon.py b/daemon.py index 20ec300..69e8532 100644 --- a/daemon.py +++ b/daemon.py @@ -1,22 +1,19 @@ -import time - -from llm_server.routes.cache import redis - -try: - import gevent.monkey - - gevent.monkey.patch_all() -except ImportError: - pass - +import argparse +import logging import os import sys +import time from pathlib import Path -from llm_server.config.load import load_config -from llm_server.database.create import create_db +from redis import Redis -from llm_server.workers.app import start_background +from llm_server.cluster.cluster_config import cluster_config +from llm_server.config.load import load_config, parse_backends +from llm_server.custom_redis import redis +from llm_server.database.create import create_db +from llm_server.logging import create_logger, logging_info, init_logging +from llm_server.routes.v1.generate_stats import generate_stats +from llm_server.workers.threader import start_background script_path = os.path.dirname(os.path.realpath(__file__)) config_path_environ = os.getenv("CONFIG_PATH") @@ -26,19 +23,46 @@ else: config_path = Path(script_path, 'config', 'config.yml') if __name__ == "__main__": - flushed_keys = redis.flush() - print('Flushed', len(flushed_keys), 'keys from Redis.') + parser = argparse.ArgumentParser(description='Daemon microservice.') + parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.") + parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.') + args = parser.parse_args() - success, config, msg = load_config(config_path, script_path) + # TODO: have this be set by either the arg or a config value + if args.debug: + logging_info.level = logging.DEBUG + + init_logging() + logger = create_logger('daemon') + logger.debug('Debug logging enabled.') + + if not args.no_reset: + Redis().flushall() + logger.info('Flushed Redis.') + + success, config, msg = load_config(config_path) if not success: - print('Failed to load config:', msg) + logger.info(f'Failed to load config: {msg}') sys.exit(1) create_db() + + cluster_config.clear() + cluster_config.load(parse_backends(config)) + + logger.info('Loading backend stats...') + generate_stats(regen=True) + start_background() - redis.set('daemon_started', 1) - print('== Daemon Setup Complete ==\n') + # Give some time for the background threads to get themselves ready to go. + time.sleep(2) - while True: - time.sleep(3600) + redis.set('daemon_started', 1) + logger.info('== Daemon Setup Complete ==') + + try: + while True: + time.sleep(3600) + except KeyboardInterrupt: + redis.set('daemon_started', 0) diff --git a/llm_server/cluster/__init__.py b/llm_server/cluster/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py new file mode 100644 index 0000000..9e2e19b --- /dev/null +++ b/llm_server/cluster/backend.py @@ -0,0 +1,117 @@ +import numpy as np + +from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend +from llm_server.cluster.stores import redis_running_models +from llm_server.custom_redis import redis +from llm_server.llm.generator import generator +from llm_server.llm.info import get_info +from llm_server.llm.vllm.vllm_backend import VLLMBackend +from llm_server.routes.queue import priority_queue +from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model + + +def get_backends_from_model(model_name: str): + return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)] + + +def get_running_models(): + return redis_running_models.keys() + + +def purge_backend_from_running_models(backend_url: str): + keys = redis_running_models.keys() + pipeline = redis_running_models.pipeline() + for model in keys: + pipeline.srem(model, backend_url) + pipeline.execute() + + +def is_valid_model(model_name: str): + return redis_running_models.exists(model_name) + + +def test_backend(backend_url: str, test_prompt: bool = False): + backend_info = cluster_config.get_backend(backend_url) + if test_prompt: + handler = VLLMBackend(backend_url) + parameters, _ = handler.get_parameters({ + "stream": False, + "temperature": 0, + "max_new_tokens": 3, + }) + data = { + 'prompt': 'test prompt', + **parameters + } + try: + success, response, err = generator(data, backend_url, timeout=10) + if not success or not response or err: + return False, {} + except: + return False, {} + i = get_info(backend_url, backend_info['mode']) + if not i.get('model'): + return False, {} + return True, i + + +def get_model_choices(regen: bool = False): + if not regen: + c = redis.getp('model_choices') + if c: + return c + + base_client_api = redis.get('base_client_api', dtype=str) + running_models = get_running_models() + model_choices = {} + for model in running_models: + b = get_backends_from_model(model) + + context_size = [] + avg_gen_per_worker = [] + concurrent_gens = 0 + for backend_url in b: + backend_info = cluster_config.get_backend(backend_url) + if backend_info.get('model_config'): + context_size.append(backend_info['model_config']['max_position_embeddings']) + if backend_info.get('average_generation_elapsed_sec'): + avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec']) + concurrent_gens += backend_info['concurrent_gens'] + + active_gen_workers = get_active_gen_workers_model(model) + proompters_in_queue = priority_queue.len(model) + + if len(avg_gen_per_worker): + average_generation_elapsed_sec = np.average(avg_gen_per_worker) + else: + average_generation_elapsed_sec = 0 + estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, concurrent_gens, active_gen_workers) + + model_choices[model] = { + 'model': model, + 'client_api': f'https://{base_client_api}/{model}', + 'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None, + 'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled', + 'backend_count': len(b), + 'estimated_wait': estimated_wait_sec, + 'queued': proompters_in_queue, + 'processing': active_gen_workers, + 'avg_generation_time': average_generation_elapsed_sec, + 'concurrent_gens': concurrent_gens + } + + if len(context_size): + model_choices[model]['context_size'] = min(context_size) + + # Python wants to sort lowercase vs. uppercase letters differently. + model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper())) + + default_backend_url = get_a_cluster_backend() + default_backend_info = cluster_config.get_backend(default_backend_url) + if not default_backend_info.get('model'): + return {}, None + default_model = default_backend_info['model'] + + redis.setp('model_choices', (model_choices, default_model)) + return model_choices, default_model diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py new file mode 100644 index 0000000..891dfc1 --- /dev/null +++ b/llm_server/cluster/cluster_config.py @@ -0,0 +1,124 @@ +import hashlib +import pickle +import traceback + +from llm_server import opts +from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle +from llm_server.cluster.stores import redis_running_models +from llm_server.custom_redis import RedisCustom +from llm_server.routes.helpers.model import estimate_model_size + + +class RedisClusterStore: + def __init__(self, name: str, **kwargs): + self.name = name + self.config_redis = RedisCustom(name, **kwargs) + + def clear(self): + self.config_redis.flush() + + def load(self, config: dict): + for k, v in config.items(): + self.add_backend(k, v) + + def add_backend(self, name: str, values: dict): + self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()}) + self.set_backend_value(name, 'online', False) + h = hashlib.sha256(name.encode('utf-8')).hexdigest() + self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}') + + def set_backend_value(self, backend: str, key: str, value): + # By storing the value as a pickle we don't have to cast anything when getting the value from Redis. + self.config_redis.hset(backend, key, pickle.dumps(value)) + + def get_backend(self, name: str): + r = self.config_redis.hgetall(name) + output = {} + for k, v in r.items(): + output[k.decode('utf8')] = pickle.loads(v) + return output + + def all(self): + keys = self.config_redis.keys('*') + if keys: + result = {} + for key in keys: + if key != f'{self.name}:____': + v = self.get_backend(key) + result[key] = v + return result + else: + return {} + + def validate_backend(self, backend_url: str): + """ + Returns the backend URL that was given, or a new one if that was offline. + :param backend_url: + :return: + """ + backend_info = self.get_backend(backend_url) + if not backend_info['online']: + old = backend_url + backend_url = get_a_cluster_backend() + print(f'Backend {old} offline. Request was redirected to {backend_url}') + return backend_url + + +cluster_config = RedisClusterStore('cluster_config') + + +def get_backends(): + backends = cluster_config.all() + result = {} + for k, v in backends.items(): + b = cluster_config.get_backend(k) + status = b.get('online', False) + priority = b['priority'] + result[k] = {'status': status, 'priority': priority} + + try: + if not opts.prioritize_by_size: + online_backends = sorted( + ((url, info) for url, info in backends.items() if info['online']), + key=lambda kv: -kv[1]['priority'], + reverse=True + ) + else: + online_backends = sorted( + ((url, info) for url, info in backends.items() if info['online']), + key=lambda kv: estimate_model_size(kv[1]['model_config']), + reverse=True + ) + offline_backends = sorted( + ((url, info) for url, info in backends.items() if not info['online']), + key=lambda kv: -kv[1]['priority'], + reverse=True + ) + return [url for url, info in online_backends], [url for url, info in offline_backends] + except KeyError: + traceback.print_exc() + print(backends) + + +def get_a_cluster_backend(model=None): + """ + Get a backend from Redis. If there are no online backends, return None. + If `model` is not supplied, we will pick one ourself. + """ + if model: + # First, determine if there are multiple backends hosting the same model. + backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)] + + # If so, create an iterator for those backends + if len(backends_hosting_model): + add_backend_cycler(model, backends_hosting_model) + cycled = redis_cycle(model) + if len(cycled): + return cycled[0] + else: + # No backend hosting that model + return None + else: + online, _ = get_backends() + if len(online): + return online[0] diff --git a/llm_server/cluster/redis_cycle.py b/llm_server/cluster/redis_cycle.py new file mode 100644 index 0000000..266241d --- /dev/null +++ b/llm_server/cluster/redis_cycle.py @@ -0,0 +1,39 @@ +import redis + +redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9) + + +def redis_cycle(list_name): + """ + Emulates itertools.cycle() but returns the complete shuffled list. + :param list_name: + :return: + """ + pipeline = redis_cycler_db.pipeline() + pipeline.lpop(list_name) + to_move = pipeline.execute()[0] + if not to_move: + return [] + pipeline.rpush(list_name, to_move) + pipeline.lrange(list_name, 0, -1) + results = pipeline.execute() + new_list = results[-1] + return [x.decode('utf-8') for x in new_list] + + +def add_backend_cycler(list_name: str, new_elements: list): + existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)] + existing_set = set(existing_elements) + + with redis_cycler_db.pipeline() as pipe: + # Add elements + for element in new_elements: + if element not in existing_set: + pipe.rpush(list_name, element) + + # Remove elements + for element in existing_set: + if element not in new_elements: + pipe.lrem(list_name, 0, element) + + pipe.execute() diff --git a/llm_server/cluster/stores.py b/llm_server/cluster/stores.py new file mode 100644 index 0000000..c0cbdcc --- /dev/null +++ b/llm_server/cluster/stores.py @@ -0,0 +1,3 @@ +from llm_server.custom_redis import RedisCustom + +redis_running_models = RedisCustom('running_models') diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py new file mode 100644 index 0000000..9652db9 --- /dev/null +++ b/llm_server/cluster/worker.py @@ -0,0 +1,38 @@ +import time +from threading import Thread + +from llm_server.cluster.backend import test_backend +from llm_server.cluster.cluster_config import cluster_config +from llm_server.cluster.stores import redis_running_models + + +def cluster_worker(): + counter = 0 + while True: + test_prompt = False + if counter % 4 == 0: + # Only send a test prompt every 120 seconds. + test_prompt = True + threads = [] + for n, v in cluster_config.all().items(): + thread = Thread(target=check_backend, args=(n, v, test_prompt)) + thread.start() + threads.append(thread) + for thread in threads: + thread.join() + time.sleep(15) + counter += 1 + + +def check_backend(n, v, test_prompt): + online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt) + if online: + running_model = backend_info['model'] + for k, v in backend_info.items(): + cluster_config.set_backend_value(n, k, v) + redis_running_models.sadd(running_model, n) + else: + for model in redis_running_models.keys(): + redis_running_models.srem(model, n) + + cluster_config.set_backend_value(n, 'online', online) diff --git a/llm_server/config/config.py b/llm_server/config/config.py index 59568d7..2c08544 100644 --- a/llm_server/config/config.py +++ b/llm_server/config/config.py @@ -28,16 +28,19 @@ config_default_vars = { 'openai_force_no_hashes': True, 'include_system_tokens_in_stats': True, 'openai_moderation_scan_last_n': 5, - 'openai_moderation_workers': 10, 'openai_org_name': 'OpenAI', 'openai_silent_trim': False, 'openai_moderation_enabled': True, - 'netdata_root': None + 'netdata_root': None, + 'show_backends': True, + 'background_homepage_cacher': True, + 'openai_moderation_timeout': 5, + 'prioritize_by_size': False } -config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] +config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name'] mode_ui_names = { - 'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), + 'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), 'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'), } diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 64469b2..cc3250c 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -3,38 +3,28 @@ import sys import openai +import llm_server from llm_server import opts from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars +from llm_server.custom_redis import redis from llm_server.database.conn import database from llm_server.database.database import get_number_of_rows -from llm_server.helpers import resolve_path -from llm_server.routes.cache import redis +from llm_server.routes.queue import PriorityQueue -def load_config(config_path, script_path): +def load_config(config_path): config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars) success, config, msg = config_loader.load_config() if not success: return success, config, msg - # Resolve relative directory to the directory of the script - if config['database_path'].startswith('./'): - config['database_path'] = resolve_path(script_path, config['database_path'].strip('./')) - - if config['mode'] not in ['oobabooga', 'vllm']: - print('Unknown mode:', config['mode']) - sys.exit(1) - # TODO: this is atrocious - opts.mode = config['mode'] opts.auth_required = config['auth_required'] opts.log_prompts = config['log_prompts'] - opts.concurrent_gens = config['concurrent_gens'] opts.frontend_api_client = config['frontend_api_client'] - opts.context_size = config['token_limit'] opts.show_num_prompts = config['show_num_prompts'] opts.show_uptime = config['show_uptime'] - opts.backend_url = config['backend_url'].strip('/') + opts.cluster = config['cluster'] opts.show_total_output_tokens = config['show_total_output_tokens'] opts.netdata_root = config['netdata_root'] opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip'] @@ -53,10 +43,20 @@ def load_config(config_path, script_path): opts.openai_force_no_hashes = config['openai_force_no_hashes'] opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats'] opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n'] - opts.openai_moderation_workers = config['openai_moderation_workers'] opts.openai_org_name = config['openai_org_name'] opts.openai_silent_trim = config['openai_silent_trim'] opts.openai_moderation_enabled = config['openai_moderation_enabled'] + opts.show_backends = config['show_backends'] + opts.background_homepage_cacher = config['background_homepage_cacher'] + opts.openai_moderation_timeout = config['openai_moderation_timeout'] + opts.frontend_api_mode = config['frontend_api_mode'] + opts.prioritize_by_size = config['prioritize_by_size'] + + # Scale the number of workers. + for item in config['cluster']: + opts.cluster_workers += item['concurrent_gens'] + + llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']]) if opts.openai_expose_our_model and not opts.openai_api_key: print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.') @@ -78,6 +78,16 @@ def load_config(config_path, script_path): if config['load_num_prompts']: redis.set('proompts', get_number_of_rows('prompts')) - redis.set('backend_mode', opts.mode) - return success, config, msg + + +def parse_backends(config): + if not config.get('cluster'): + return False + cluster = config.get('cluster') + config = {} + for item in cluster: + backend_url = item['backend_url'].strip('/') + item['backend_url'] = backend_url + config[backend_url] = item + return config diff --git a/llm_server/config/redis_config.py b/llm_server/config/redis_config.py new file mode 100644 index 0000000..06ab1d3 --- /dev/null +++ b/llm_server/config/redis_config.py @@ -0,0 +1,3 @@ +from llm_server.custom_redis import RedisCustom + +redis_config = RedisCustom('redis_config') diff --git a/llm_server/routes/cache.py b/llm_server/custom_redis.py similarity index 52% rename from llm_server/routes/cache.py rename to llm_server/custom_redis.py index d7046db..a055537 100644 --- a/llm_server/routes/cache.py +++ b/llm_server/custom_redis.py @@ -1,24 +1,27 @@ +import pickle import sys import traceback -from typing import Callable, List, Mapping, Union +from typing import Callable, List, Mapping, Optional, Union import redis as redis_pkg import simplejson as json from flask_caching import Cache from redis import Redis -from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT +from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT -flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) +flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) ONE_MONTH_SECONDS = 2678000 -class RedisWrapper: +class RedisCustom(Redis): """ - A wrapper class to set prefixes to keys. + A simple wrapper class for Redis to create a "namespace" within a DB, + which simplyifies key management. """ def __init__(self, prefix, **kwargs): + super().__init__() self.redis = Redis(**kwargs) self.prefix = prefix try: @@ -34,12 +37,11 @@ class RedisWrapper: def set(self, key, value, ex: Union[ExpiryT, None] = None): return self.redis.set(self._key(key), value, ex=ex) - def get(self, key, dtype=None, default=None): - """ - :param key: - :param dtype: convert to this type - :return: - """ + def get(self, key, default=None, dtype=None): + # TODO: use pickle + import inspect + if inspect.isclass(default): + raise Exception d = self.redis.get(self._key(key)) if dtype and d: @@ -108,7 +110,10 @@ class RedisWrapper: ): return self.redis.hincrby(self._key(name), key, amount) - def hdel(self, name: str, *keys: List): + def zcard(self, name: KeyT): + return self.redis.zcard(self._key(name)) + + def hdel(self, name: str, *keys: str): return self.redis.hdel(self._key(name), *keys) def hget( @@ -129,9 +134,62 @@ class RedisWrapper: ): return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt) + def lpush(self, name: str, *values: FieldT): + return self.redis.lpush(self._key(name), *values) + + def hset( + self, + name: str, + key: Optional = None, + value=None, + mapping: Optional[dict] = None, + items: Optional[list] = None, + ): + return self.redis.hset(self._key(name), key, value, mapping, items) + def hkeys(self, name: str): return self.redis.hkeys(self._key(name)) + def hmget(self, name: str, keys: List, *args: List): + return self.redis.hmget(self._key(name), keys, *args) + + def hgetall(self, name: str): + return self.redis.hgetall(self._key(name)) + + def keys(self, pattern: PatternT = "*", **kwargs): + raw_keys = self.redis.keys(self._key(pattern), **kwargs) + keys = [] + for key in raw_keys: + p = key.decode('utf-8').split(':') + if len(p) >= 2: + # Delete prefix + del p[0] + k = ':'.join(p) + if k != '____': + keys.append(k) + return keys + + def pipeline(self, transaction=True, shard_hint=None): + return self.redis.pipeline(transaction, shard_hint) + + def smembers(self, name: str): + return self.redis.smembers(self._key(name)) + + def spop(self, name: str, count: Optional[int] = None): + return self.redis.spop(self._key(name), count) + + def rpoplpush(self, src, dst): + return self.redis.rpoplpush(src, dst) + + def zpopmin(self, name: KeyT, count: Union[int, None] = None): + return self.redis.zpopmin(self._key(name), count) + + def exists(self, *names: KeyT): + n = [] + for name in names: + n.append(self._key(name)) + return self.redis.exists(*n) + def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): return self.set(key, json.dumps(dict_value), ex=ex) @@ -142,6 +200,15 @@ class RedisWrapper: else: return json.loads(r.decode("utf-8")) + def setp(self, name, value): + self.redis.set(self._key(name), pickle.dumps(value)) + + def getp(self, name: str): + r = self.redis.get(self._key(name)) + if r: + return pickle.loads(r) + return r + def flush(self): flushed = [] for key in self.redis.scan_iter(f'{self.prefix}:*'): @@ -149,5 +216,40 @@ class RedisWrapper: self.redis.delete(key) return flushed + def flushall(self, asynchronous: bool = ..., **kwargs) -> bool: + self.flush() + return True -redis = RedisWrapper('local_llm') + def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool: + self.flush() + return True + + def lrange(self, name: str, start: int, end: int): + return self.redis.lrange(self._key(name), start, end) + + def delete(self, *names: KeyT): + return self.redis.delete(*[self._key(i) for i in names]) + + def lpop(self, name: str, count: Optional[int] = None): + return self.redis.lpop(self._key(name), count) + + def zrange( + self, + name: KeyT, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: Union[type, Callable] = float, + byscore: bool = False, + bylex: bool = False, + offset: int = None, + num: int = None, + ): + return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num) + + def zrem(self, name: KeyT, *values: FieldT): + return self.redis.zrem(self._key(name), *values) + + +redis = RedisCustom('local_llm') diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py index 25f3326..f63f555 100644 --- a/llm_server/database/conn.py +++ b/llm_server/database/conn.py @@ -5,20 +5,20 @@ class DatabaseConnection: host: str = None username: str = None password: str = None - database: str = None + database_name: str = None - def init_db(self, host, username, password, database): + def init_db(self, host, username, password, database_name): self.host = host self.username = username self.password = password - self.database = database + self.database_name = database_name def cursor(self): db = pymysql.connect( host=self.host, user=self.username, password=self.password, - database=self.database, + database=self.database_name, charset='utf8mb4', autocommit=True, ) diff --git a/llm_server/database/database.py b/llm_server/database/database.py index 9bfe578..d6bd6b2 100644 --- a/llm_server/database/database.py +++ b/llm_server/database/database.py @@ -1,15 +1,19 @@ import json import time import traceback +from typing import Union -import llm_server from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config from llm_server.database.conn import database -from llm_server.llm.vllm import tokenize -from llm_server.routes.cache import redis +from llm_server.llm import get_token_count -def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False): +def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False): + assert isinstance(prompt, str) + assert isinstance(backend_url, str) + + # Try not to shove JSON into the database. if isinstance(response, dict) and response.get('results'): response = response['results'][0]['text'] try: @@ -19,10 +23,11 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe except: pass - prompt_tokens = llm_server.llm.get_token_count(prompt) + prompt_tokens = get_token_count(prompt, backend_url) + if not is_error: if not response_tokens: - response_tokens = llm_server.llm.get_token_count(response) + response_tokens = get_token_count(response, backend_url) else: response_tokens = None @@ -43,7 +48,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe if token: increment_token_uses(token) - running_model = redis.get('running_model', str, 'ERROR') + backend_info = cluster_config.get_backend(backend_url) + running_model = backend_info.get('model') + backend_mode = backend_info['mode'] timestamp = int(time.time()) cursor = database.cursor() try: @@ -52,7 +59,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """, - (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) + (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp)) finally: cursor.close() @@ -179,3 +186,21 @@ def increment_token_uses(token): cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,)) finally: cursor.close() + + +def get_token_ratelimit(token): + priority = 9990 + simultaneous_ip = opts.simultaneous_requests_per_ip + if token: + cursor = database.cursor() + try: + cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,)) + result = cursor.fetchone() + if result: + priority, simultaneous_ip = result + if simultaneous_ip is None: + # No ratelimit for this token if null + simultaneous_ip = 999999999 + finally: + cursor.close() + return priority, simultaneous_ip diff --git a/llm_server/database/log_to_db.py b/llm_server/database/log_to_db.py new file mode 100644 index 0000000..75bcaab --- /dev/null +++ b/llm_server/database/log_to_db.py @@ -0,0 +1,30 @@ +import pickle +from typing import Union + +from redis import Redis + + +def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False): + assert isinstance(prompt, str) + assert isinstance(backend_url, str) + + r = Redis(host='localhost', port=6379, db=3) + data = { + 'function': 'log_prompt', + 'args': [], + 'kwargs': { + 'ip': ip, + 'token': token, + 'prompt': prompt, + 'response': response, + 'gen_time': gen_time, + 'parameters': parameters, + 'headers': dict(headers) if headers else headers, + 'backend_response_code': backend_response_code, + 'request_url': request_url, + 'backend_url': backend_url, + 'response_tokens': response_tokens, + 'is_error': is_error + } + } + r.publish('database-logger', pickle.dumps(data)) diff --git a/llm_server/helpers.py b/llm_server/helpers.py index 44b436b..91f3b15 100644 --- a/llm_server/helpers.py +++ b/llm_server/helpers.py @@ -8,7 +8,7 @@ import simplejson as json from flask import make_response from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def resolve_path(*p: str): @@ -54,13 +54,14 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys def round_up_base(n, base): if base == 0: - print('round_up_base DIVIDE BY ZERO ERROR????', n, base) + # TODO: I don't think passing (0, 0) to this function is a sign of any underlying issues. + # print('round_up_base DIVIDE BY ZERO ERROR????', n, base) return 0 return math.ceil(n / base) * base def auto_set_base_client_api(request): - http_host = redis.get('http_host', str) + http_host = redis.get('http_host', dtype=str) host = request.headers.get("Host") if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host): # If the current http_host is not an IP, don't do anything. diff --git a/llm_server/integer.py b/llm_server/integer.py deleted file mode 100644 index 1410dd1..0000000 --- a/llm_server/integer.py +++ /dev/null @@ -1,12 +0,0 @@ -import threading - - -class ThreadSafeInteger: - def __init__(self, value=0): - self.value = value - self._value_lock = threading.Lock() - - def increment(self): - with self._value_lock: - self.value += 1 - return self.value diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py index 742b1a5..ac6702b 100644 --- a/llm_server/llm/__init__.py +++ b/llm_server/llm/__init__.py @@ -1,11 +1,30 @@ +import tiktoken + +from llm_server.cluster.cluster_config import cluster_config from llm_server.llm import oobabooga, vllm -from llm_server.routes.cache import redis +from llm_server.logging import create_logger -def get_token_count(prompt: str): - backend_mode = redis.get('backend_mode', str) +def fallback_tokenizer(prompt: str): + tokenizer = tiktoken.get_encoding("cl100k_base") + return len(tokenizer.encode(prompt)) + 10 + + +def get_token_count(prompt: str, backend_url: str): + backend_url = cluster_config.validate_backend(backend_url) + if not backend_url: + logger = create_logger('tokenizer') + logger.warning('using fallback tokenizer as there is no valid backend') + return fallback_tokenizer(prompt) + + backend_mode = cluster_config.get_backend(backend_url).get('mode') + if not backend_mode: + logger = create_logger('tokenizer') + logger.warning("using fallback tokenizer as the backend isn't initalized") + return fallback_tokenizer(prompt) + if backend_mode == 'vllm': - return vllm.tokenize(prompt) + return vllm.tokenize(prompt, backend_url) elif backend_mode == 'ooba': return oobabooga.tokenize(prompt) else: diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py index 5dd2093..c924d38 100644 --- a/llm_server/llm/generator.py +++ b/llm_server/llm/generator.py @@ -1,14 +1,15 @@ from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config -def generator(request_json_body): - if opts.mode == 'oobabooga': +def generator(request_json_body, cluster_backend, timeout: int = None): + mode = cluster_config.get_backend(cluster_backend)['mode'] + if mode == 'ooba': # from .oobabooga.generate import generate # return generate(request_json_body) raise NotImplementedError - elif opts.mode == 'vllm': + elif mode == 'vllm': from .vllm.generate import generate - r = generate(request_json_body) - return r + return generate(request_json_body, cluster_backend, timeout=timeout) else: raise Exception diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py index 5a529ba..d1218e2 100644 --- a/llm_server/llm/info.py +++ b/llm_server/llm/info.py @@ -3,23 +3,35 @@ import requests from llm_server import opts -def get_running_model(): - # TODO: cache the results for 1 min so we don't have to keep calling the backend - # TODO: only use one try/catch - - if opts.mode == 'oobabooga': +def get_running_model(backend_url: str, mode: str): + if mode == 'ooba': try: - backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) + backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['result'], None except Exception as e: return False, e - elif opts.mode == 'vllm': + elif mode == 'vllm': try: - backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) + backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl) r_json = backend_response.json() return r_json['model'], None except Exception as e: return False, e else: raise Exception + + +def get_info(backend_url: str, mode: str): + if mode == 'ooba': + return {} + # raise NotImplementedError + elif mode == 'vllm': + try: + r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout) + j = r.json() + except Exception as e: + return {} + return j + else: + raise Exception diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py index 1c11c17..f864b18 100644 --- a/llm_server/llm/llm_backend.py +++ b/llm_server/llm/llm_backend.py @@ -2,14 +2,17 @@ from typing import Tuple, Union import flask -from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config from llm_server.llm import get_token_count -from llm_server.routes.cache import redis class LLMBackend: _default_params: dict + def __init__(self, backend_url: str): + self.backend_url = backend_url + self.backend_info = cluster_config.get_backend(self.backend_url) + def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers): raise NotImplementedError @@ -32,14 +35,16 @@ class LLMBackend: """ If a backend needs to do other checks not related to the prompt or parameters. Default is no extra checks preformed. + :param request: + :param prompt: :param parameters: :return: """ return True, None def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]: - prompt_len = get_token_count(prompt) - if prompt_len > opts.context_size - 10: - model_name = redis.get('running_model', str, 'NO MODEL ERROR') - return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size' + prompt_len = get_token_count(prompt, self.backend_url) + token_limit = self.backend_info['model_config']['max_position_embeddings'] + if prompt_len > token_limit - 10: + return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size' return True, None diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py index 4336756..18fe6b1 100644 --- a/llm_server/llm/oobabooga/ooba_backend.py +++ b/llm_server/llm/oobabooga/ooba_backend.py @@ -1,78 +1,6 @@ -from flask import jsonify - from ..llm_backend import LLMBackend -from ...database.database import log_prompt -from ...helpers import safe_list_get -from ...routes.cache import redis -from ...routes.helpers.client import format_sillytavern_err -from ...routes.helpers.http import validate_json class OobaboogaBackend(LLMBackend): - default_params = {} - - def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers): - raise NotImplementedError('need to implement default_params') - - backend_err = False - response_valid_json, response_json_body = validate_json(response) - if response: - try: - # Be extra careful when getting attributes from the response object - response_status_code = response.status_code - except: - response_status_code = 0 - else: - response_status_code = None - - # =============================================== - - # We encountered an error - if not success or not response or error_msg: - if not error_msg or error_msg == '': - error_msg = 'Unknown error.' - else: - error_msg = error_msg.strip('.') + '.' - backend_response = format_sillytavern_err(error_msg, 'error') - log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True) - return jsonify({ - 'code': 500, - 'msg': error_msg, - 'results': [{'text': backend_response}] - }), 400 - - # =============================================== - - if response_valid_json: - backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text') - if not backend_response: - # Ooba doesn't return any error messages so we will just tell the client an error occurred - backend_err = True - backend_response = format_sillytavern_err( - f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.', - 'error') - response_json_body['results'][0]['text'] = backend_response - - if not backend_err: - redis.incr('proompts') - - log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err) - return jsonify({ - **response_json_body - }), 200 - else: - backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error') - log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True) - return jsonify({ - 'code': 500, - 'msg': 'the backend did not return valid JSON', - 'results': [{'text': backend_response}] - }), 400 - - def validate_params(self, params_dict: dict): - # No validation required - return True, None - - def get_parameters(self, parameters): - del parameters['prompt'] - return parameters + def __int__(self): + return diff --git a/llm_server/llm/openai/moderation.py b/llm_server/llm/openai/moderation.py index 53e234d..f62241d 100644 --- a/llm_server/llm/openai/moderation.py +++ b/llm_server/llm/openai/moderation.py @@ -10,7 +10,7 @@ def check_moderation_endpoint(prompt: str): } response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10) if response.status_code != 200: - print(response.text) + print('moderation failed:', response) response.raise_for_status() response = response.json() diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py new file mode 100644 index 0000000..ef07a08 --- /dev/null +++ b/llm_server/llm/openai/oai_to_vllm.py @@ -0,0 +1,97 @@ +from flask import jsonify + +from llm_server import opts + + +def oai_to_vllm(request_json_body, stop_hashes: bool, mode): + if not request_json_body.get('stop'): + request_json_body['stop'] = [] + if not isinstance(request_json_body['stop'], list): + # It is a string, so create a list with the existing element. + request_json_body['stop'] = [request_json_body['stop']] + + if stop_hashes: + if opts.openai_force_no_hashes: + request_json_body['stop'].append('###') + else: + # TODO: make stopping strings a configurable + request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT']) + else: + request_json_body['stop'].extend(['user:', 'assistant:']) + + if request_json_body.get('frequency_penalty', 0) < -2: + request_json_body['frequency_penalty'] = -2 + elif request_json_body.get('frequency_penalty', 0) > 2: + request_json_body['frequency_penalty'] = 2 + + if mode == 'vllm' and request_json_body.get('top_p') == 0: + request_json_body['top_p'] = 0.01 + + request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens) + if request_json_body['max_tokens'] == 0: + # We don't want to set any defaults here. + del request_json_body['max_tokens'] + + return request_json_body + + +def format_oai_err(err_msg): + print('OAI ERROR MESSAGE:', err_msg) + return jsonify({ + "error": { + "message": err_msg, + "type": "invalid_request_error", + "param": None, + "code": None + } + }), 400 + + +def validate_oai(parameters): + if parameters.get('messages'): + for m in parameters['messages']: + if m['role'].lower() not in ['assistant', 'user', 'system']: + return format_oai_err('messages role must be assistant, user, or system') + + if parameters.get('temperature', 0) > 2: + return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'") + if parameters.get('temperature', 0) < 0: + return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'") + + if parameters.get('top_p', 1) > 2: + return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'") + if parameters.get('top_p', 1) < 0: + return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'") + + if parameters.get('presence_penalty', 1) > 2: + return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'") + if parameters.get('presence_penalty', 1) < -2: + return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'") + + if parameters.get('top_p', 1) > 2: + return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'") + if parameters.get('top_p', 1) < 0: + return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'") + + if parameters.get('top_p', 1) > 2: + return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'") + if parameters.get('top_p', 1) < 0: + return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'") + + if parameters.get('max_tokens', 2) < 1: + return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'") + + +def return_invalid_model_err(requested_model: str): + if requested_model: + msg = f"The model `{requested_model}` does not exist" + else: + msg = "The requested model does not exist" + return jsonify({ + "error": { + "message": msg, + "type": "invalid_request_error", + "param": None, + "code": "model_not_found" + } + }), 404 diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py index d5b64e3..daec3dc 100644 --- a/llm_server/llm/openai/transform.py +++ b/llm_server/llm/openai/transform.py @@ -2,86 +2,35 @@ import concurrent.futures import re import secrets import string -import time import traceback from typing import Dict, List import tiktoken -from flask import jsonify, make_response -import llm_server from llm_server import opts from llm_server.llm import get_token_count -from llm_server.routes.cache import redis ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line. ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line. -def build_openai_response(prompt, response, model=None): - # Seperate the user's prompt from the context - x = prompt.split('### USER:') - if len(x) > 1: - prompt = re.sub(r'\n$', '', x[-1].strip(' ')) - - # Make sure the bot doesn't put any other instructions in its response - # y = response.split('\n### ') - # if len(y) > 1: - # response = re.sub(r'\n$', '', y[0].strip(' ')) - response = re.sub(ANTI_RESPONSE_RE, '', response) - response = re.sub(ANTI_CONTINUATION_RE, '', response) - - # TODO: async/await - prompt_tokens = llm_server.llm.get_token_count(prompt) - response_tokens = llm_server.llm.get_token_count(response) - running_model = redis.get('running_model', str, 'ERROR') - - response = make_response(jsonify({ - "id": f"chatcmpl-{generate_oai_string(30)}", - "object": "chat.completion", - "created": int(time.time()), - "model": running_model if opts.openai_expose_our_model else model, - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": response, - }, - "logprobs": None, - "finish_reason": "stop" - }], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": response_tokens, - "total_tokens": prompt_tokens + response_tokens - } - }), 200) - - stats = redis.get('proxy_stats', dict) - if stats: - response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] - return response - - def generate_oai_string(length=24): alphabet = string.ascii_letters + string.digits return ''.join(secrets.choice(alphabet) for i in range(length)) -def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]: - tokenizer = tiktoken.get_encoding("cl100k_base") - - def get_token_count_tiktoken_thread(msg): - return len(tokenizer.encode(msg["content"])) +def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]: + def get_token_count_thread(msg): + return get_token_count(msg["content"], backend_url) with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: - token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt)) + token_counts = list(executor.map(get_token_count_thread, prompt)) total_tokens = sum(token_counts) - formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens + formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens # If total tokens exceed the limit, start trimming - if total_tokens > context_token_limit: + if total_tokens + formatting_tokens > context_token_limit: while True: while total_tokens + formatting_tokens > context_token_limit: # Calculate the index to start removing messages from @@ -94,22 +43,43 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) - if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt): break - def get_token_count_thread(msg): - return get_token_count(msg["content"]) - with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: token_counts = list(executor.map(get_token_count_thread, prompt)) total_tokens = sum(token_counts) - formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens - + formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens if total_tokens + formatting_tokens > context_token_limit: # Start over, but this time calculate the token count using the backend with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: token_counts = list(executor.map(get_token_count_thread, prompt)) else: break + return prompt + +def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str: + tokenizer = tiktoken.get_encoding("cl100k_base") + token_count = get_token_count(prompt, backend_url) + + # If total tokens exceed the limit, start trimming + if token_count > context_token_limit: + while True: + while token_count > context_token_limit: + # Calculate the index to start removing characters from + remove_index = len(prompt) // 3 + + while remove_index < len(prompt): + prompt = prompt[:remove_index] + prompt[remove_index + 100:] + token_count = len(tokenizer.encode(prompt)) + if token_count <= context_token_limit or remove_index == len(prompt): + break + + token_count = get_token_count(prompt, backend_url) + if token_count > context_token_limit: + # Start over, but this time calculate the token count using the backend + token_count = get_token_count(prompt, backend_url) + else: + break return prompt @@ -117,8 +87,9 @@ def transform_messages_to_prompt(oai_messages): try: prompt = f'### INSTRUCTION: {opts.openai_system_prompt}' for msg in oai_messages: - if not msg.get('content') or not msg.get('role'): + if 'content' not in msg.keys() or 'role' not in msg.keys(): return False + msg['content'] = str(msg['content']) # Prevent any weird issues. if msg['role'] == 'system': prompt += f'### INSTRUCTION: {msg["content"]}\n\n' elif msg['role'] == 'user': @@ -126,7 +97,7 @@ def transform_messages_to_prompt(oai_messages): elif msg['role'] == 'assistant': prompt += f'### ASSISTANT: {msg["content"]}\n\n' else: - return False + raise Exception(f'Unknown role: {msg["role"]}') except Exception as e: # TODO: use logging traceback.print_exc() diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py index 1549f2e..31cd511 100644 --- a/llm_server/llm/vllm/generate.py +++ b/llm_server/llm/vllm/generate.py @@ -1,80 +1,21 @@ """ This file is used by the worker that processes requests. """ -import json -import time -from uuid import uuid4 import requests -import llm_server from llm_server import opts -from llm_server.routes.cache import redis # TODO: make the VLMM backend return TPS and time elapsed # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py def prepare_json(json_data: dict): - # logit_bias is not currently supported - # del json_data['logit_bias'] - # Convert back to VLLM. json_data['max_tokens'] = json_data.pop('max_new_tokens') return json_data -def transform_to_text(json_request, api_response): - """ - This is to convert a streaming request to a non-streamed request. Don't think this is nessesary. - :param json_request: - :param api_response: - :return: - """ - prompt = transform_prompt_to_text(json_request['messages']) - text = '' - finish_reason = None - for line in api_response.split('\n'): - if line.startswith('data:'): - try: - data = json.loads(line[5:].strip()) - except json.decoder.JSONDecodeError: - break - if 'choices' in data: - for choice in data['choices']: - if 'delta' in choice and 'content' in choice['delta']: - text += choice['delta']['content'] - if data['choices'][0]['finish_reason']: - finish_reason = data['choices'][0]['finish_reason'] - - prompt_tokens = len(llm_server.llm.get_token_count(prompt)) - completion_tokens = len(llm_server.llm.get_token_count(text)) - running_model = redis.get('running_model', str, 'ERROR') - - # https://platform.openai.com/docs/api-reference/making-requests?lang=python - return { - "id": str(uuid4()), - "object": "chat.completion", - "created": int(time.time()), - "model": running_model, - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": prompt_tokens + completion_tokens - }, - "choices": [ - { - "message": { - "role": "assistant", - "content": text - }, - "finish_reason": finish_reason, - "index": 0 - } - ] - } - - def transform_prompt_to_text(prompt: list): text = '' for item in prompt: @@ -82,26 +23,26 @@ def transform_prompt_to_text(prompt: list): return text.strip('\n') -def handle_blocking_request(json_data: dict): +def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10): try: - r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout) except requests.exceptions.ReadTimeout: - print(f'Failed to reach VLLM inference endpoint - request to backend timed out') + # print(f'Failed to reach VLLM inference endpoint - request to backend timed out') return False, None, 'Request to backend timed out' except Exception as e: - print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') + # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') return False, None, 'Request to backend encountered error' if r.status_code != 200: - print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}') + # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}') return False, r, f'Backend returned {r.status_code}' return True, r, None -def generate(json_data: dict): +def generate(json_data: dict, cluster_backend, timeout: int = None): if json_data.get('stream'): try: - return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout) except Exception as e: - print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}') + return False else: - return handle_blocking_request(json_data) + return handle_blocking_request(json_data, cluster_backend, timeout=timeout) diff --git a/llm_server/llm/vllm/info.py b/llm_server/llm/vllm/info.py index 996c614..0142301 100644 --- a/llm_server/llm/vllm/info.py +++ b/llm_server/llm/vllm/info.py @@ -1,3 +1,7 @@ +import requests + +from llm_server import opts + vllm_info = """

Important: This endpoint is running vllm and not all Oobabooga parameters are supported.

Supported Parameters: """ \ No newline at end of file +""" diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py index a698fd6..bdb6650 100644 --- a/llm_server/llm/vllm/tokenize.py +++ b/llm_server/llm/vllm/tokenize.py @@ -1,26 +1,51 @@ +import concurrent.futures + import requests import tiktoken from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config +from llm_server.logging import create_logger -def tokenize(prompt: str) -> int: +def tokenize(prompt: str, backend_url: str) -> int: + assert backend_url + assert isinstance(backend_url, str) + if not prompt: # The tokenizers have issues when the prompt is None. return 0 + assert isinstance(prompt, str) + + logger = create_logger('tokenizer') + + # The backend could have died between when the request was + # submitted and now, so let's double check it's still online. + backend_url = cluster_config.validate_backend(backend_url) + tokenizer = tiktoken.get_encoding("cl100k_base") - # First we tokenize it locally to determine if it's worth sending it to the backend. - initial_estimate = len(tokenizer.encode(prompt)) - if initial_estimate <= opts.context_size + 200: + # Split the prompt into 2000 character chunks + chunk_size = 2000 + chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)] + + # Define a function to send a chunk to the server + def send_chunk(chunk): try: - r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) + r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) j = r.json() return j['length'] except Exception as e: - print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') - return len(tokenizer.encode(prompt)) + 10 - else: - # If the result was greater than our context size, return the estimate. - # We won't be sending it through the backend so it does't need to be accurage. - return initial_estimate + logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}') + return len(tokenizer.encode(chunk)) + 10 + + # Use a ThreadPoolExecutor to send all chunks to the server at once + with concurrent.futures.ThreadPoolExecutor() as executor: + future_to_chunk = {executor.submit(send_chunk, chunk): chunk for chunk in chunks} + for future in concurrent.futures.as_completed(future_to_chunk): + chunk = future_to_chunk[future] + try: + data = future.result() + except Exception as exc: + logger.warning('%r generated an exception: %s' % (chunk, exc)) + return sum(future.result() for future in future_to_chunk) diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py index e5b0fad..5c12b45 100644 --- a/llm_server/llm/vllm/vllm_backend.py +++ b/llm_server/llm/vllm/vllm_backend.py @@ -1,10 +1,9 @@ -import threading from typing import Tuple from flask import jsonify from vllm import SamplingParams -from llm_server.database.database import log_prompt +from llm_server.database.log_to_db import log_to_db from llm_server.llm.llm_backend import LLMBackend @@ -19,16 +18,8 @@ class VLLMBackend(LLMBackend): # Failsafe backend_response = '' - r_url = request.url - - def background_task(): - log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url, - response_tokens=response_json_body.get('details', {}).get('generated_tokens')) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task) - thread.start() - thread.join() + log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url, + response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url) return jsonify({'results': [{'text': backend_response}]}), 200 @@ -38,14 +29,20 @@ class VLLMBackend(LLMBackend): top_k = parameters.get('top_k', self._default_params['top_k']) if top_k <= 0: top_k = -1 + + # TODO: support more params sampling_params = SamplingParams( temperature=parameters.get('temperature', self._default_params['temperature']), top_p=parameters.get('top_p', self._default_params['top_p']), top_k=top_k, use_beam_search=True if parameters.get('num_beams', 0) > 1 else False, - stop=parameters.get('stopping_strings', self._default_params['stop']), + stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))), ignore_eos=parameters.get('ban_eos_token', False), - max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) + max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']), + presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']), + frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']), + length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']), + early_stopping=parameters.get('early_stopping', self._default_params['early_stopping']) ) except ValueError as e: return None, str(e).strip('.') diff --git a/llm_server/logging.py b/llm_server/logging.py new file mode 100644 index 0000000..7e9aa74 --- /dev/null +++ b/llm_server/logging.py @@ -0,0 +1,52 @@ +import logging + +import coloredlogs + +from llm_server import opts + + +class LoggingInfo: + def __init__(self): + self._level = logging.INFO + self._format = opts.LOGGING_FORMAT + + @property + def level(self): + return self._level + + @level.setter + def level(self, value): + self._level = value + + @property + def format(self): + return self._format + + @format.setter + def format(self, value): + self._format = value + + +logging_info = LoggingInfo() + + +def init_logging(): + """ + Set up the parent logger. + :return: + """ + logger = logging.getLogger('llm_server') + logger.setLevel(logging_info.level) + + +def create_logger(name): + logger = logging.getLogger('llm_server').getChild(name) + logger.setLevel(logging_info.level) + if not logger.handlers: + handler = logging.StreamHandler() + handler.setLevel(logging_info.level) + formatter = logging.Formatter(logging_info.format) + handler.setFormatter(formatter) + logger.addHandler(handler) + coloredlogs.install(logger=logger, level=logging_info.level) + return logger diff --git a/llm_server/messages.py b/llm_server/messages.py new file mode 100644 index 0000000..c7e3eb7 --- /dev/null +++ b/llm_server/messages.py @@ -0,0 +1 @@ +BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.' diff --git a/llm_server/netdata.py b/llm_server/netdata.py deleted file mode 100644 index f37c109..0000000 --- a/llm_server/netdata.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from datetime import datetime, timedelta - -import requests - -from llm_server import opts - - -def get_power_states(): - gpu_num = 0 - output = {} - while True: - url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state" - try: - response = requests.get(url, timeout=10) - if response.status_code != 200: - break - data = json.loads(response.text) - power_state_data = data['data'][0] - power_state = None - for i in range(1, len(power_state_data)): - if power_state_data[i] == 1: - power_state = data['labels'][i] - break - output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p')) - except Exception as e: - print('Failed to fetch Netdata metrics:', e) - return output - gpu_num += 1 - return output - - -def get_gpu_wh(gpu_id: int): - chart_name = f"nvidia_smi.gpu{gpu_id}_power" - now = datetime.now() - one_hour_ago = now - timedelta(hours=1) - num_seconds = int((now - one_hour_ago).total_seconds()) - params = { - "chart": chart_name, - "after": int(one_hour_ago.timestamp()), - "before": int(now.timestamp()), - "points": num_seconds, - "group": "second", - "format": "json", - "options": "absolute|jsonwrap" - } - response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10) - data = json.loads(response.text) - total_power_usage_watts = sum(point[1] for point in data['result']['data']) - # total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1) - total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3) - return total_power_usage_kwh diff --git a/llm_server/opts.py b/llm_server/opts.py index de23c7a..f75ba94 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -1,12 +1,11 @@ # Read-only global variables +# Uppercase variables are read-only globals. +# Lowercase variables are ones that are set on startup and are never changed. + # TODO: rewrite the config system so I don't have to add every single config default here -running_model = 'ERROR' -concurrent_gens = 3 -mode = 'oobabooga' -backend_url = None -context_size = 5555 +frontend_api_mode = 'ooba' max_new_tokens = 500 auth_required = False log_prompts = False @@ -33,7 +32,15 @@ openai_expose_our_model = False openai_force_no_hashes = True include_system_tokens_in_stats = True openai_moderation_scan_last_n = 5 -openai_moderation_workers = 10 openai_org_name = 'OpenAI' openai_silent_trim = False openai_moderation_enabled = True +cluster = {} +show_backends = True +background_homepage_cacher = True +openai_moderation_timeout = 5 +prioritize_by_size = False +cluster_workers = 0 +redis_stream_timeout = 25000 + +LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s" diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py index f3ea0f4..6e8c1ad 100644 --- a/llm_server/pre_fork.py +++ b/llm_server/pre_fork.py @@ -1,21 +1,9 @@ import sys -from redis import Redis - -from llm_server.routes.cache import redis -from llm_server.routes.v1.generate_stats import generate_stats +from llm_server.custom_redis import redis def server_startup(s): - if not redis.get('daemon_started', bool): + if not redis.get('daemon_started', dtype=bool): print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') sys.exit(1) - - # Flush the RedisPriorityQueue database. - queue_redis = Redis(host='localhost', port=6379, db=15) - for key in queue_redis.scan_iter('*'): - queue_redis.delete(key) - - # Cache the initial stats - print('Loading backend stats...') - generate_stats() diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py index 48e721e..5031b8b 100644 --- a/llm_server/routes/helpers/client.py +++ b/llm_server/routes/helpers/client.py @@ -1,11 +1,18 @@ -from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis -def format_sillytavern_err(msg: str, level: str = 'info'): - http_host = redis.get('http_host', str) +def format_sillytavern_err(msg: str, backend_url: str = None, error_type: str = 'info'): + if backend_url: + cluster_backend_hash = cluster_config.get_backend(backend_url)['hash'] + else: + cluster_backend_hash = 'none' + http_host = redis.get('http_host', dtype=str) return f"""``` === MESSAGE FROM LLM MIDDLEWARE AT {http_host} === --> {level.upper()} <- +-> {error_type.upper()} <- {msg} +``` +``` +BACKEND: {cluster_backend_hash} ```""" diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py index 2fa1190..a3f1906 100644 --- a/llm_server/routes/helpers/http.py +++ b/llm_server/routes/helpers/http.py @@ -100,4 +100,4 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas j = json.loads(str(data)) return True, j except Exception as e: - return False, e + return False, e \ No newline at end of file diff --git a/llm_server/routes/helpers/model.py b/llm_server/routes/helpers/model.py new file mode 100644 index 0000000..bf18b66 --- /dev/null +++ b/llm_server/routes/helpers/model.py @@ -0,0 +1,15 @@ +def estimate_model_size(config: dict): + """ + Estimate the size of a model from its config. No idea if this is correct, + but it allows us to compare models. + :param config: + :return: + """ + vocab_size = config.get('vocab_size') + hidden_size = config.get('hidden_size') + num_hidden_layers = config.get('num_hidden_layers') + intermediate_size = config.get('intermediate_size') + if vocab_size and hidden_size and num_hidden_layers and intermediate_size: + total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3))) + return int(total_params / 1e9) + return 0 diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py index d6b02e2..aadda78 100644 --- a/llm_server/routes/ooba_request_handler.py +++ b/llm_server/routes/ooba_request_handler.py @@ -3,8 +3,8 @@ from typing import Tuple import flask from flask import jsonify, request -from llm_server import opts -from llm_server.database.database import log_prompt +from llm_server import messages, opts +from llm_server.database.log_to_db import log_to_db from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.request_handler import RequestHandler @@ -13,8 +13,11 @@ class OobaRequestHandler(RequestHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - def handle_request(self): + def handle_request(self, return_ok: bool = True): assert not self.used + if self.offline: + print('This backend is offline:', messages.BACKEND_OFFLINE) + return self.handle_error(messages.BACKEND_OFFLINE) request_valid, invalid_response = self.validate_request() if not request_valid: @@ -25,14 +28,19 @@ class OobaRequestHandler(RequestHandler): llm_request = {**self.parameters, 'prompt': prompt} _, backend_response = self.generate_response(llm_request) - return backend_response + if return_ok: + # Always return 200 so ST displays our error messages + return backend_response[0], 200 + else: + # The OpenAI route needs to detect 429 errors. + return backend_response def handle_ratelimited(self, do_log: bool = True): msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' backend_response = self.handle_error(msg) if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True) - return backend_response[0], 200 # We only return the response from handle_error(), not the error code + log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True) + return backend_response[0], 429 def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true' @@ -40,7 +48,7 @@ class OobaRequestHandler(RequestHandler): # TODO: how to format this response_msg = error_msg else: - response_msg = format_sillytavern_err(error_msg, error_type) + response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url) return jsonify({ 'results': [{'text': response_msg}] diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py index 67febc9..c3837e4 100644 --- a/llm_server/routes/openai/__init__.py +++ b/llm_server/routes/openai/__init__.py @@ -5,9 +5,11 @@ from ..server_error import handle_server_error from ... import opts openai_bp = Blueprint('openai/v1/', __name__) +openai_model_bp = Blueprint('openai/', __name__) @openai_bp.before_request +@openai_model_bp.before_request def before_oai_request(): if not opts.enable_openi_compatible_backend: return 'The OpenAI-compatible backend is disabled.', 401 @@ -15,8 +17,22 @@ def before_oai_request(): @openai_bp.errorhandler(500) +@openai_model_bp.errorhandler(500) def handle_error(e): - return handle_server_error(e) + """ + Found Codes: + "auth_subrequest_error" + """ + + print('OAI returning error:', e) + return jsonify({ + "error": { + "message": "Internal server error", + "type": "auth_subrequest_error", + "param": None, + "code": "internal_error" + } + }), 500 from .models import openai_list_models diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index cc27dce..9ccc15f 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -1,113 +1,175 @@ import json -import threading import time import traceback +import ujson from flask import Response, jsonify, request +from redis import Redis -from . import openai_bp -from ..cache import redis +from llm_server.custom_redis import redis +from . import openai_bp, openai_model_bp from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler +from ..queue import priority_queue from ... import opts -from ...database.database import log_prompt -from ...llm.generator import generator -from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt -from ...llm.vllm import tokenize +from ...database.log_to_db import log_to_db +from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai +from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit # TODO: add rate-limit headers? + @openai_bp.route('/chat/completions', methods=['POST']) -def openai_chat_completions(): +@openai_model_bp.route('//v1/chat/completions', methods=['POST']) +def openai_chat_completions(model_name=None): request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'): return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 else: - handler = OpenAIRequestHandler(request, request_json_body) - if request_json_body.get('stream'): - if not opts.enable_streaming: - # TODO: return a proper OAI error message - return 'disabled', 401 + handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) + if handler.offline: + return return_invalid_model_err(model_name) - if opts.mode != 'vllm': - # TODO: implement other backends - raise NotImplementedError - - response_status_code = 0 - start_time = time.time() - request_valid, invalid_response = handler.validate_request() - if not request_valid: - # TODO: simulate OAI here - raise Exception('TODO: simulate OAI here') - else: - handler.prompt = transform_messages_to_prompt(request_json_body['messages']) - msg_to_backend = { - **handler.parameters, - 'prompt': handler.prompt, - 'stream': True, - } - try: - response = generator(msg_to_backend) - r_headers = dict(request.headers) - r_url = request.url - model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model') - oai_string = generate_oai_string(30) - - def generate(): - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue - - data = { - "id": f"chatcmpl-{oai_string}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' - - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - - def background_task(): - generated_tokens = tokenize(generated_text) - log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens) - - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task) - thread.start() - thread.join() - - return Response(generate(), mimetype='text/event-stream') - except: - # TODO: simulate OAI here - raise Exception - else: + if not request_json_body.get('stream'): try: return handler.handle_request() except Exception: traceback.print_exc() return 'Internal server error', 500 + else: + if not opts.enable_streaming: + return 'Streaming disabled', 403 + + invalid_oai_err_msg = validate_oai(handler.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + + handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode']) + + handler.parameters, e = handler.get_parameters() + handler.request_json_body = { + 'messages': handler.request_json_body['messages'], + 'model': handler.request_json_body['model'], + **handler.parameters + } + + if opts.openai_silent_trim: + handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)) + else: + handler.prompt = transform_messages_to_prompt(handler.request.json['messages']) + if not handler.prompt: + # Prevent issues on the backend. + return 'Invalid prompt', 400 + + # Need to set the prompt in the JSON body since that's what the inference worker expects. + handler.request_json_body['prompt'] = handler.prompt + + start_time = time.time() + + request_valid, invalid_response = handler.validate_request() + if not request_valid: + return invalid_response + + event = None + if not handler.is_client_ratelimited(): + event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) + if not event: + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + None, + None, + handler.parameters, + request.headers, + 429, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() + + try: + r_headers = dict(request.headers) + r_url = request.url + model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') + oai_string = generate_oai_string(30) + + # Need to do this before we enter generate() since we want to be able to + # return a 408 if necessary. + _, stream_name, error_msg = event.wait() + if error_msg: + print('OAI failed to start streaming:', error_msg) + stream_name = None # set to null so that the Finally ignores it. + return 'Request Timeout', 408 + + def generate(): + stream_redis = Redis(db=8) + generated_text = '' + try: + last_id = '0-0' + while True: + stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout) + if not stream_data: + print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.") + yield 'data: [DONE]\n\n' + else: + for stream_index, item in stream_data[0][1]: + last_id = stream_index + timestamp = int(stream_index.decode('utf-8').split('-')[0]) + data = ujson.loads(item[b'data']) + if data['error']: + # Not printing error since we can just check the daemon log. + print('OAI streaming encountered error') + yield 'data: [DONE]\n\n' + return + elif data['new']: + response = { + "id": f"chatcmpl-{oai_string}", + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": data['new'] + }, + "finish_reason": None + } + ] + } + generated_text = generated_text + data['new'] + yield f'data: {json.dumps(response)}\n\n' + elif data['completed']: + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + 200, + r_url, + handler.backend_url, + ) + return + except GeneratorExit: + return + except Exception: + traceback.print_exc() + yield 'data: [DONE]\n\n' + finally: + if event: + redis.publish(f'notifications:{event.event_id}', 'canceled') + if stream_name: + stream_redis.delete(stream_name) + + return Response(generate(), mimetype='text/event-stream') + except Exception: + traceback.print_exc() + return 'INTERNAL SERVER', 500 diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 503f628..2524b17 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -1,38 +1,68 @@ import time import traceback -from flask import jsonify, make_response, request +import simplejson as json +import ujson +from flask import Response, jsonify, request +from redis import Redis -from . import openai_bp -from ..cache import redis -from ..helpers.client import format_sillytavern_err +from llm_server.custom_redis import redis +from . import openai_bp, openai_model_bp from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler +from ..queue import priority_queue from ... import opts +from ...database.log_to_db import log_to_db from ...llm import get_token_count -from ...llm.openai.transform import build_openai_response, generate_oai_string +from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai +from ...llm.openai.transform import generate_oai_string, trim_string_to_fit # TODO: add rate-limit headers? @openai_bp.route('/completions', methods=['POST']) -def openai_completions(): +@openai_model_bp.route('//v1/completions', methods=['POST']) +def openai_completions(model_name=None): request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: - try: - response, status_code = OobaRequestHandler(request).handle_request() - if status_code != 200: - return status_code + handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name) + if handler.offline: + return return_invalid_model_err(model_name) + + if handler.cluster_backend_info['mode'] != 'vllm': + # TODO: implement other backends + raise NotImplementedError + + invalid_oai_err_msg = validate_oai(handler.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode']) + + if opts.openai_silent_trim: + handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) + else: + # The handle_request() call below will load the prompt so we don't have + # to do anything else here. + pass + + handler.request_json_body['prompt'] = handler.prompt + + if not request_json_body.get('stream'): + invalid_oai_err_msg = validate_oai(request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + response, status_code = handler.handle_request(return_ok=False) + if status_code == 429: + return handler.handle_ratelimited() output = response.json['results'][0]['text'] - # TODO: async/await - prompt_tokens = get_token_count(request_json_body['prompt']) - response_tokens = get_token_count(output) - running_model = redis.get('running_model', str, 'ERROR') + prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url) + response_tokens = get_token_count(output, handler.backend_url) + running_model = redis.get('running_model', 'ERROR', dtype=str) - response = make_response(jsonify({ + response = jsonify({ "id": f"cmpl-{generate_oai_string(30)}", "object": "text_completion", "created": int(time.time()), @@ -42,7 +72,7 @@ def openai_completions(): "text": output, "index": 0, "logprobs": None, - "finish_reason": None + "finish_reason": "stop" } ], "usage": { @@ -50,12 +80,141 @@ def openai_completions(): "completion_tokens": response_tokens, "total_tokens": prompt_tokens + response_tokens } - }), 200) + }) - stats = redis.get('proxy_stats', dict) - if stats: - response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] - return response - except Exception: - traceback.print_exc() - return 'Internal Server Error', 500 + # TODO: + # stats = redis.get('proxy_stats', dtype=dict) + # if stats: + # response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] + return response, 200 + else: + if not opts.enable_streaming: + return 'Streaming disabled', 403 + + request_valid, invalid_response = handler.validate_request() + if not request_valid: + return invalid_response + + handler.parameters, _ = handler.get_parameters() + handler.request_json_body = { + 'prompt': handler.request_json_body['prompt'], + 'model': handler.request_json_body['model'], + **handler.parameters + } + + invalid_oai_err_msg = validate_oai(handler.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + + if opts.openai_silent_trim: + handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']] + if not handler.prompt: + # Prevent issues on the backend. + return 'Invalid prompt', 400 + + start_time = time.time() + + request_valid, invalid_response = handler.validate_request() + if not request_valid: + return invalid_response + + event = None + if not handler.is_client_ratelimited(): + event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) + if not event: + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + None, + None, + handler.parameters, + request.headers, + 429, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() + + try: + r_headers = dict(request.headers) + r_url = request.url + model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') + oai_string = generate_oai_string(30) + + _, stream_name, error_msg = event.wait() + if error_msg: + print('OAI failed to start streaming:', error_msg) + stream_name = None + return 'Request Timeout', 408 + + def generate(): + stream_redis = Redis(db=8) + generated_text = '' + try: + last_id = '0-0' + while True: + stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout) + if not stream_data: + print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.") + yield 'data: [DONE]\n\n' + else: + for stream_index, item in stream_data[0][1]: + last_id = stream_index + timestamp = int(stream_index.decode('utf-8').split('-')[0]) + data = ujson.loads(item[b'data']) + if data['error']: + print('OAI streaming encountered error') + yield 'data: [DONE]\n\n' + return + elif data['new']: + response = { + "id": f"cmpl-{oai_string}", + "object": "text_completion", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": data['new'] + }, + "finish_reason": None + } + ] + } + generated_text = generated_text + data['new'] + yield f'data: {json.dumps(response)}\n\n' + elif data['completed']: + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + 200, + r_url, + handler.backend_url, + ) + return + except GeneratorExit: + # This should be triggered if a client disconnects early. + return + except Exception: + traceback.print_exc() + yield 'data: [DONE]\n\n' + finally: + if event: + redis.publish(f'notifications:{event.event_id}', 'canceled') + if stream_name: + stream_redis.delete(stream_name) + + return Response(generate(), mimetype='text/event-stream') + except Exception: + traceback.print_exc() + return 'INTERNAL SERVER', 500 diff --git a/llm_server/routes/openai/info.py b/llm_server/routes/openai/info.py index 54959ae..4fc578a 100644 --- a/llm_server/routes/openai/info.py +++ b/llm_server/routes/openai/info.py @@ -1,7 +1,7 @@ from flask import Response from . import openai_bp -from ..cache import flask_cache +from llm_server.custom_redis import flask_cache from ... import opts diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py index 47223e7..2ff0629 100644 --- a/llm_server/routes/openai/models.py +++ b/llm_server/routes/openai/models.py @@ -3,59 +3,58 @@ import traceback import requests from flask import jsonify +from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis from . import openai_bp -from ..cache import ONE_MONTH_SECONDS, flask_cache, redis from ..stats import server_start_time from ... import opts +from ...cluster.cluster_config import cluster_config, get_a_cluster_backend from ...helpers import jsonify_pretty -from ...llm.info import get_running_model +from ...llm.openai.transform import generate_oai_string @openai_bp.route('/models', methods=['GET']) @flask_cache.cached(timeout=60, query_string=True) def openai_list_models(): - model, error = get_running_model() - if not model: + model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model') + if not model_name: response = jsonify({ 'code': 502, 'msg': 'failed to reach backend', - 'type': error.__class__.__name__ }), 500 # return 500 so Cloudflare doesn't intercept us else: - running_model = redis.get('running_model', str, 'ERROR') + running_model = redis.get('running_model', 'ERROR', dtype=str) oai = fetch_openai_models() - r = [] + r = { + "object": "list", + "data": oai + } + # TODO: verify this works if opts.openai_expose_our_model: - r = [{ - "object": "list", - "data": [ + r["data"].insert(0, { + "id": running_model, + "object": "model", + "created": int(server_start_time.timestamp()), + "owned_by": opts.llm_middleware_name, + "permission": [ { "id": running_model, - "object": "model", + "object": "model_permission", "created": int(server_start_time.timestamp()), - "owned_by": opts.llm_middleware_name, - "permission": [ - { - "id": running_model, - "object": "model_permission", - "created": int(server_start_time.timestamp()), - "allow_create_engine": False, - "allow_sampling": False, - "allow_logprobs": False, - "allow_search_indices": False, - "allow_view": True, - "allow_fine_tuning": False, - "organization": "*", - "group": None, - "is_blocking": False - } - ], - "root": None, - "parent": None + "allow_create_engine": False, + "allow_sampling": False, + "allow_logprobs": False, + "allow_search_indices": False, + "allow_view": True, + "allow_fine_tuning": False, + "organization": "*", + "group": None, + "is_blocking": False } - ] - }] - response = jsonify_pretty(r + oai), 200 + ], + "root": None, + "parent": None + }) + response = jsonify_pretty(r), 200 return response @@ -64,7 +63,14 @@ def fetch_openai_models(): if opts.openai_api_key: try: response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10) - return response.json()['data'] + j = response.json()['data'] + + # The "modelperm" string appears to be user-specific, so we'll + # randomize it just to be safe. + for model in range(len(j)): + for p in range(len(j[model]['permission'])): + j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}' + return j except: traceback.print_exc() return [] diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py index f626490..2dafedb 100644 --- a/llm_server/routes/openai/simulated.py +++ b/llm_server/routes/openai/simulated.py @@ -1,7 +1,7 @@ from flask import jsonify from . import openai_bp -from ..cache import ONE_MONTH_SECONDS, flask_cache +from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache from ...llm.openai.transform import generate_oai_string from ..stats import server_start_time @@ -17,7 +17,7 @@ def openai_organizations(): "id": f"org-{generate_oai_string(24)}", "created": int(server_start_time.timestamp()), "title": "Personal", - "name": "user-abcdefghijklmnopqrstuvwx", + "name": f"user-{generate_oai_string(24)}", "description": "Personal org for bobjoe@0.0.0.0", "personal": True, "is_default": True, diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index d97ea09..170eb77 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -1,14 +1,21 @@ import json +import re +import time import traceback from typing import Tuple from uuid import uuid4 import flask -from flask import jsonify +from flask import Response, jsonify, make_response from llm_server import opts +from llm_server.cluster.backend import get_model_choices +from llm_server.custom_redis import redis from llm_server.database.database import is_api_key_moderated -from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit +from llm_server.database.log_to_db import log_to_db +from llm_server.llm import get_token_count +from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err +from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit from llm_server.routes.request_handler import RequestHandler from llm_server.workers.moderator import add_moderation_task, get_results @@ -20,20 +27,37 @@ class OpenAIRequestHandler(RequestHandler): def handle_request(self) -> Tuple[flask.Response, int]: assert not self.used + if self.offline: + msg = return_invalid_model_err(self.selected_model) + print('OAI Offline:', msg) + return self.handle_error(msg) if opts.openai_silent_trim: - oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size) + oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url) else: oai_messages = self.request.json['messages'] self.prompt = transform_messages_to_prompt(oai_messages) + self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) + request_valid, invalid_response = self.validate_request() if not request_valid: return invalid_response - if opts.openai_api_key and is_api_key_moderated(self.token): + if not self.prompt: + # TODO: format this as an openai error message + return Response('Invalid prompt'), 400 + + # TODO: support Ooba backend + self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) + + invalid_oai_err_msg = validate_oai(self.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + + if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token): try: - # Gather the last message from the user and all preceeding system messages + # Gather the last message from the user and all preceding system messages msg_l = self.request.json['messages'].copy() msg_l.reverse() tag = uuid4() @@ -49,33 +73,40 @@ class OpenAIRequestHandler(RequestHandler): self.prompt = transform_messages_to_prompt(self.request.json['messages']) except Exception as e: print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}') - print(traceback.format_exc()) - - # Reconstruct the request JSON with the validated parameters and prompt. - self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE']) - if opts.openai_force_no_hashes: - self.parameters['stop'].append('### ') - - if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0: - self.request_json_body['top_p'] = 0.01 + traceback.print_exc() llm_request = {**self.parameters, 'prompt': self.prompt} (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request) - model = self.request_json_body.get('model') if success: - return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code + return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code else: return backend_response, backend_response_status_code def handle_ratelimited(self, do_log: bool = True): - # TODO: return a simulated OpenAI error message - # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another. - return 'Ratelimited', 429 + model_choices, default_model = get_model_choices() + default_model_info = model_choices[default_model] + w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2 + response = jsonify({ + "error": { + "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.", + "type": "rate_limit_exceeded", + "param": None, + "code": None + } + }) + response.headers['x-ratelimit-limit-requests'] = '2' + response.headers['x-ratelimit-remaining-requests'] = '0' + response.headers['x-ratelimit-reset-requests'] = f"{w}s" + + if do_log: + log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True) + + return response, 429 def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: - # TODO: return a simulated OpenAI error message + print('OAI Error:', error_msg) return jsonify({ "error": { "message": "Invalid request, check your parameters and try again.", @@ -84,3 +115,51 @@ class OpenAIRequestHandler(RequestHandler): "code": None } }), 400 + + def build_openai_response(self, prompt, response, model=None): + # Seperate the user's prompt from the context + x = prompt.split('### USER:') + if len(x) > 1: + prompt = re.sub(r'\n$', '', x[-1].strip(' ')) + + # Make sure the bot doesn't put any other instructions in its response + response = re.sub(ANTI_RESPONSE_RE, '', response) + response = re.sub(ANTI_CONTINUATION_RE, '', response) + + prompt_tokens = get_token_count(prompt, self.backend_url) + response_tokens = get_token_count(response, self.backend_url) + running_model = redis.get('running_model', 'ERROR', dtype=str) + + response = make_response(jsonify({ + "id": f"chatcmpl-{generate_oai_string(30)}", + "object": "chat.completion", + "created": int(time.time()), + "model": running_model if opts.openai_expose_our_model else model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": response, + }, + "logprobs": None, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": response_tokens, + "total_tokens": prompt_tokens + response_tokens + } + }), 200) + return response + + def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]: + self.parameters, parameters_invalid_msg = self.get_parameters() + if not self.parameters: + print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg) + return False, (Response('Invalid request, check your parameters and try again.'), 400) + invalid_oai_err_msg = validate_oai(self.parameters) + if invalid_oai_err_msg: + return False, invalid_oai_err_msg + # self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode']) + # If the parameters were invalid, let the superclass deal with it. + return super().validate_request(prompt, do_log) diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index 84cc614..ee66580 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -1,12 +1,15 @@ -import json import pickle import time +from typing import Tuple from uuid import uuid4 +import ujson as json from redis import Redis from llm_server import opts -from llm_server.routes.cache import redis +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import RedisCustom, redis +from llm_server.database.database import get_token_ratelimit def increment_ip_count(client_ip: str, redis_key): @@ -20,24 +23,30 @@ def decrement_ip_count(client_ip: str, redis_key): class RedisPriorityQueue: - def __init__(self): - self.redis = Redis(host='localhost', port=6379, db=15) - self.pubsub = self.redis.pubsub() - self.pubsub.subscribe('events') + """ + A queue for a specific backend. + """ - def put(self, item, priority): - event = DataEvent() + def __init__(self, name, db: int = 12): + self.name = name + self.redis = RedisCustom(name, db=db) + + def put(self, item, priority: int, selected_model: str, do_stream: bool = False): + # TODO: remove this when we're sure nothing strange is happening + assert item is not None + assert priority is not None + assert selected_model is not None # Check if the IP is already in the dictionary and if it has reached the limit - ip_count = self.redis.hget('queued_ip_count', item[1]) - if ip_count: - ip_count = int(ip_count) - if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0: - print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') + ip_count = self.get_ip_request_count(item[1]) + _, simultaneous_ip = get_token_ratelimit(item[2]) + if ip_count and int(ip_count) >= simultaneous_ip and priority != 0: + print(f'Rejecting request from {item[1]} - {ip_count} request queued.') return None # reject the request - self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority}) - self.increment_ip_count(item[1], 'queued_ip_count') + timestamp = time.time() + event = DataEvent() + self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority}) return event def get(self): @@ -45,31 +54,59 @@ class RedisPriorityQueue: data = self.redis.zpopmin('queue') if data: item = json.loads(data[0][0]) - client_ip = item[0][1] - self.decrement_ip_count(client_ip, 'queued_ip_count') return item time.sleep(0.1) # wait for something to be added to the queue - def increment_ip_count(self, client_ip: str, redis_key): - self.redis.hincrby(redis_key, client_ip, 1) - - def decrement_ip_count(self, client_ip: str, redis_key): - new_count = self.redis.hincrby(redis_key, client_ip, -1) - if new_count <= 0: - self.redis.hdel(redis_key, client_ip) - def __len__(self): return self.redis.zcard('queue') - def get_queued_ip_count(self, client_ip: str): - q = self.redis.hget('queued_ip_count', client_ip) - if not q: - return 0 - return 0 + def get_ip_request_count(self, client_ip: str): + """ + Get the number of requests in the queue from a specific IP. + This is a bit inefficient since we iterate over the entire queue, but + keeps the queue as a single point of truth instead of tracking a separate hashed + set which can get confusing. + If we run into slowdowns in the future, we should go back to the hashed set approach. + :param client_ip: + :return: + """ + start_time = time.time() + items = self.redis.zrange('queue', 0, -1) + count = 0 + for item in items: + item_data = json.loads(item) + if item_data[0][1] == client_ip: + count += 1 + elapsed_time = time.time() - start_time + if elapsed_time > 0.5: + raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!") + return count + + def flush(self): + self.redis.flush() + + def items(self): + return self.redis.zrange('queue', 0, -1) + + def cleanup(self): + now = time.time() + for item in self.items(): + item_data = json.loads(item) + timestamp = item_data[-2] + if now - timestamp > opts.backend_generate_request_timeout: + self.redis.zrem('queue', 0, item) + event_id = item_data[1] + event = DataEvent(event_id) + event.set((False, None, 'closed')) + print('Removed timed-out item from queue:', event_id) class DataEvent: - def __init__(self, event_id=None): + """ + Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol). + """ + + def __init__(self, event_id: str = None): self.event_id = event_id if event_id else str(uuid4()) self.redis = Redis(host='localhost', port=6379, db=14) self.pubsub = self.redis.pubsub() @@ -84,15 +121,89 @@ class DataEvent: return pickle.loads(item['data']) -priority_queue = RedisPriorityQueue() +def update_active_workers(key: str, operation: str): + if operation == 'incr': + redis.incr(f'active_gen_workers:{key}') + elif operation == 'decr': + redis.decr(f'active_gen_workers:{key}') + if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0: + redis.set(f'active_gen_workers:{key}', 0) -def incr_active_workers(): - redis.incr('active_gen_workers') +def incr_active_workers(selected_model: str, backend_url: str): + update_active_workers(selected_model, 'incr') + update_active_workers(backend_url, 'incr') -def decr_active_workers(): - redis.decr('active_gen_workers') - new_count = redis.get('active_gen_workers', int, 0) - if new_count < 0: - redis.set('active_gen_workers', 0) +def decr_active_workers(selected_model: str, backend_url: str): + update_active_workers(selected_model, 'decr') + update_active_workers(backend_url, 'decr') + + +class PriorityQueue: + """ + Helper class to wrangler all the different queues. + """ + + def __init__(self, backends: set = None): + """ + Only have to load the backends once. + :param backends: + """ + self.redis = Redis(host='localhost', port=6379, db=9) + if backends: + for item in backends: + self.redis.sadd('backends', item) + + def get_backends(self): + return {x.decode('utf-8') for x in self.redis.smembers('backends')} + + def get_queued_ip_count(self, client_ip: str): + count = 0 + for backend_url in self.get_backends(): + queue = RedisPriorityQueue(backend_url) + count += queue.get_ip_request_count(client_ip) + return count + + def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False): + queue = RedisPriorityQueue(backend_url) + return queue.put(item, priority, selected_model, do_stream) + + def activity(self): + lines = [] + status_redis = RedisCustom('worker_status') + for worker in status_redis.keys(): + lines.append((worker, status_redis.getp(worker))) + return sorted(lines) + + def len(self, model_name): + count = 0 + backends_with_models = set() + for k in self.get_backends(): + info = cluster_config.get_backend(k) + if info.get('model') == model_name: + backends_with_models.add(k) + for backend_url in backends_with_models: + count += len(RedisPriorityQueue(backend_url)) + return count + + def __len__(self): + count = 0 + p = set() + for backend_url in self.get_backends(): + queue = RedisPriorityQueue(backend_url) + p.add((backend_url, len(queue))) + count += len(queue) + return count + + def flush(self): + for k in self.redis.keys(): + q = json.loads(self.redis.get(k)) + q.flush() + self.redis.set(k, json.dumps(q)) + + def flush_db(self): + self.redis.flushdb() + + +priority_queue = PriorityQueue() diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index 4b1f640..f4abfa6 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -5,23 +5,22 @@ import flask from flask import Response, request from llm_server import opts -from llm_server.database.conn import database -from llm_server.database.database import log_prompt +from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend +from llm_server.custom_redis import redis +from llm_server.database.database import get_token_ratelimit +from llm_server.database.log_to_db import log_to_db from llm_server.helpers import auto_set_base_client_api from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.routes.auth import parse_token -from llm_server.routes.cache import redis from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.queue import priority_queue -DEFAULT_PRIORITY = 9999 - class RequestHandler: - def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None): + def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None): self.request = incoming_request - self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' + # self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' # Routes need to validate it, here we just load it if incoming_json: @@ -34,11 +33,38 @@ class RequestHandler: self.start_time = time.time() self.client_ip = self.get_client_ip() self.token = self.get_auth_token() - self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit() - self.backend = get_backend() + self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token) self.parameters = None self.used = False - redis.zadd('recent_prompters', {self.client_ip: time.time()}) + + # This is null by default since most handlers need to transform the prompt in a specific way. + self.prompt = None + + self.selected_model = selected_model + self.backend_url = get_a_cluster_backend(selected_model) + self.cluster_backend_info = cluster_config.get_backend(self.backend_url) + + # Debug stuff + # if not self.cluster_backend_info.get('mode'): + # print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info) + # if not self.cluster_backend_info.get('model'): + # print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info) + # if not self.cluster_backend_info.get('model_config'): + # print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info) + + if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'): + self.offline = True + else: + self.offline = False + self.selected_model = self.cluster_backend_info['model'] + self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url) + if self.token and not self.token.startswith('SYSTEM__'): + # "recent_prompters" is only used for stats. + redis.zadd('recent_prompters', {self.client_ip: time.time()}) + + def check_online(self) -> bool: + self.cluster_backend_info = cluster_config.get_backend(self.backend_url) + return self.cluster_backend_info['online'] def get_auth_token(self): if self.request_json_body.get('X-API-KEY'): @@ -49,6 +75,8 @@ class RequestHandler: return parse_token(self.request.headers['Authorization']) def get_client_ip(self): + if self.request.headers.get('Llm-Connecting-Ip'): + return self.request.headers['Llm-Connecting-Ip'] if self.request.headers.get('X-Connecting-IP'): return self.request.headers.get('X-Connecting-IP') elif self.request.headers.get('Cf-Connecting-Ip'): @@ -58,26 +86,7 @@ class RequestHandler: else: return self.request.remote_addr - def get_token_ratelimit(self): - priority = DEFAULT_PRIORITY - simultaneous_ip = opts.simultaneous_requests_per_ip - if self.token: - cursor = database.cursor() - try: - cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,)) - result = cursor.fetchone() - if result: - priority, simultaneous_ip = result - if simultaneous_ip is None: - # No ratelimit for this token if null - simultaneous_ip = 999999999 - finally: - cursor.close() - return priority, simultaneous_ip - def get_parameters(self): - if self.request_json_body.get('max_tokens'): - self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens') parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body) return parameters, parameters_invalid_msg @@ -119,7 +128,7 @@ class RequestHandler: backend_response = self.handle_error(combined_error_message, 'Validation Error') if do_log: - log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) + log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True) return False, backend_response return True, (None, 0) @@ -131,14 +140,18 @@ class RequestHandler: request_valid, invalid_response = self.validate_request(prompt, do_log=True) if not request_valid: return (False, None, None, 0), invalid_response - event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority) + event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model) else: event = None if not event: return (False, None, None, 0), self.handle_ratelimited() + # TODO: add wait timeout success, response, error_msg = event.wait() + if error_msg == 'closed': + return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408) + end_time = time.time() elapsed_time = end_time - self.start_time @@ -160,7 +173,17 @@ class RequestHandler: else: error_msg = error_msg.strip('.') + '.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + log_to_db(ip=self.client_ip, + token=self.token, + prompt=prompt, + response=backend_response[0].data.decode('utf-8'), + gen_time=None, + parameters=self.parameters, + headers=dict(self.request.headers), + backend_response_code=response_status_code, + request_url=self.request.url, + backend_url=self.backend_url, + is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -180,7 +203,7 @@ class RequestHandler: if return_json_err: error_msg = 'The backend did not return valid JSON.' backend_response = self.handle_error(error_msg) - log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) + log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True) return (False, None, None, 0), backend_response # =============================================== @@ -189,22 +212,29 @@ class RequestHandler: return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) def is_client_ratelimited(self) -> bool: + if self.token_priority == 0: + return False + queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip)) x = redis.hget('processing_ips', self.client_ip) if x: processing_ip = int(x) else: processing_ip = 0 - if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0: - return False - else: - print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.') + + if queued_ip_count + processing_ip >= self.token_simultaneous_ip: + print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued') return True + else: + return False def handle_request(self) -> Tuple[flask.Response, int]: # Must include this in your child. - # if self.used: - # raise Exception('Can only use a RequestHandler object once.') + # assert not self.used + # if self.offline: + # msg = f'{self.selected_model} is not a valid model choice.' + # print(msg) + # return format_sillytavern_err(msg) raise NotImplementedError def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]: @@ -214,11 +244,11 @@ class RequestHandler: raise NotImplementedError -def get_backend(): - if opts.mode == 'oobabooga': - return OobaboogaBackend() - elif opts.mode == 'vllm': - return VLLMBackend() +def get_backend_handler(mode, backend_url: str): + if mode == 'oobabooga': + return OobaboogaBackend(backend_url) + elif mode == 'vllm': + return VLLMBackend(backend_url) else: raise Exception diff --git a/llm_server/routes/server_error.py b/llm_server/routes/server_error.py index fec3836..a6d6f99 100644 --- a/llm_server/routes/server_error.py +++ b/llm_server/routes/server_error.py @@ -1,3 +1,3 @@ def handle_server_error(e): - print(e) - return {'error': True}, 500 + print('Internal Error:', e) + return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500 diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py index a6e9e17..7f3b2fe 100644 --- a/llm_server/routes/stats.py +++ b/llm_server/routes/stats.py @@ -1,33 +1,11 @@ from datetime import datetime -from llm_server.routes.cache import redis - -# proompters_5_min = 0 -# concurrent_semaphore = Semaphore(concurrent_gens) +from llm_server.custom_redis import redis +from llm_server.helpers import round_up_base server_start_time = datetime.now() -# TODO: do I need this? -# def elapsed_times_cleanup(): -# global wait_in_queue_elapsed -# while True: -# current_time = time.time() -# with wait_in_queue_elapsed_lock: -# global wait_in_queue_elapsed -# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60] -# time.sleep(1) - - -def calculate_avg_gen_time(): - # Get the average generation time from Redis - average_generation_time = redis.get('average_generation_time') - if average_generation_time is None: - return 0 - else: - return float(average_generation_time) - - def get_total_proompts(): count = redis.get('proompts') if count is None: @@ -37,10 +15,27 @@ def get_total_proompts(): return count -def get_active_gen_workers(): - active_gen_workers = redis.get('active_gen_workers') - if active_gen_workers is None: - count = 0 +def get_active_gen_workers_model(selected_model: str = None): + return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0) + + +def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers): + if active_gen_workers < concurrent_gens: + return 0 + elif active_gen_workers >= concurrent_gens: + # Calculate how long it will take to complete the currently running gens and the queued requests. + # If the proompters in the queue are equal to the number of workers, just use the calculated generation time. + # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round + # that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally, + # Add gen_time_calc to the time to account for the currently running generations. + # This assumes that all active workers will finish at the same time, which is unlikely. + # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times. + proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \ + else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc) + return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0 + elif proompters_in_queue == 0 and active_gen_workers == 0: + # No queue, no workers + return 0 else: - count = int(active_gen_workers) - return count + # No queue + return gen_time_calc diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py index 715288f..fcdc298 100644 --- a/llm_server/routes/v1/generate.py +++ b/llm_server/routes/v1/generate.py @@ -3,18 +3,18 @@ import traceback from flask import jsonify, request from . import bp -from ..helpers.client import format_sillytavern_err from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler -@bp.route('/generate', methods=['POST']) -def generate(): +@bp.route('/v1/generate', methods=['POST']) +@bp.route('//v1/generate', methods=['POST']) +def generate(model_name=None): request_valid_json, request_json_body = validate_json(request) if not request_valid_json or not request_json_body.get('prompt'): return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 else: - handler = OobaRequestHandler(request) + handler = OobaRequestHandler(request, selected_model=model_name) try: return handler.handle_request() except Exception: diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py index e144099..a9148b3 100644 --- a/llm_server/routes/v1/generate_stats.py +++ b/llm_server/routes/v1/generate_stats.py @@ -2,83 +2,30 @@ import time from datetime import datetime from llm_server import opts +from llm_server.cluster.backend import get_model_choices +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis from llm_server.database.database import get_distinct_ips_24h, sum_column -from llm_server.helpers import deep_sort, round_up_base -from llm_server.llm.info import get_running_model -from llm_server.netdata import get_power_states -from llm_server.routes.cache import redis -from llm_server.routes.queue import priority_queue -from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time +from llm_server.helpers import deep_sort +from llm_server.routes.stats import get_total_proompts, server_start_time -def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers): - if active_gen_workers < concurrent_gens: - return 0 - elif active_gen_workers >= concurrent_gens: - # Calculate how long it will take to complete the currently running gens and the queued requests. - # If the proompters in the queue are equal to the number of workers, just use the calculated generation time. - # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round - # that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally, - # Add gen_time_calc to the time to account for the currently running generations. - # This assumes that all active workers will finish at the same time, which is unlikely. - # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times. - proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \ - else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc) - return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0 - elif proompters_in_queue == 0 and active_gen_workers == 0: - # No queue, no workers - return 0 - else: - # No queue - return gen_time_calc - - -# TODO: have routes/__init__.py point to the latest API version generate_stats() - def generate_stats(regen: bool = False): if not regen: - c = redis.get('proxy_stats', dict) + c = redis.getp('proxy_stats') if c: return c - model_name, error = get_running_model() # will return False when the fetch fails - if isinstance(model_name, bool): - online = False - else: - online = True - redis.set('running_model', model_name) + model_choices, default_model = get_model_choices(regen=True) - # t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change - # if len(t) == 0: - # estimated_wait = 0 - # else: - # waits = [elapsed for end, elapsed in t] - # estimated_wait = int(sum(waits) / len(waits)) - - active_gen_workers = get_active_gen_workers() - proompters_in_queue = len(priority_queue) - - # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM. - # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0) - - average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0) - estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers) - - if opts.netdata_root: - netdata_stats = {} - power_states = get_power_states() - for gpu, power_state in power_states.items(): - netdata_stats[gpu] = { - 'power_state': power_state, - # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu'))) - } - else: - netdata_stats = {} - - base_client_api = redis.get('base_client_api', str) + base_client_api = redis.get('base_client_api', dtype=str) proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf')) output = { + 'models': { + 'choices': model_choices, + 'default': default_model, + }, 'stats': { 'proompters': { '5_min': proompters_5_min, @@ -86,39 +33,49 @@ def generate_stats(regen: bool = False): }, 'proompts_total': get_total_proompts() if opts.show_num_prompts else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, - 'average_generation_elapsed_sec': int(average_generation_time), # 'estimated_avg_tps': estimated_avg_tps, 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, + 'num_backends': len(cluster_config.all()) if opts.show_backends else None, }, - 'online': online, 'endpoints': { 'blocking': f'https://{base_client_api}', 'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, }, - 'queue': { - 'processing': active_gen_workers, - 'queued': proompters_in_queue, - 'estimated_wait_sec': int(estimated_wait_sec), - }, 'timestamp': int(time.time()), 'config': { 'gatekeeper': 'none' if opts.auth_required is False else 'token', - 'context_size': opts.context_size, - 'concurrent': opts.concurrent_gens, - 'model': opts.manual_model_name if opts.manual_model_name else model_name, - 'mode': opts.mode, 'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip, + 'api_mode': opts.frontend_api_mode }, 'keys': { 'openaiKeys': '∞', 'anthropicKeys': '∞', }, - 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None, - 'nvidia': netdata_stats + 'backends': {}, + 'online': len(model_choices) > 0 } + + # TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself + + if opts.show_backends: + for backend_url, v in cluster_config.all().items(): + backend_info = cluster_config.get_backend(backend_url) + if not backend_info['online']: + continue + backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None + output['backends'][backend_info['hash']] = { + 'uptime': backend_uptime, + 'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1), + 'model': backend_info['model'], + 'mode': backend_info['mode'], + 'nvidia': backend_info['nvidia'], + 'priority': backend_info['priority'], + } + result = deep_sort(output) # It may take a bit to get the base client API, so don't cache until then. if base_client_api: - redis.set_dict('proxy_stats', result) # Cache with no expiry + redis.setp('proxy_stats', result) + return result diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 0fc8f40..3ed2f58 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -1,186 +1,200 @@ import json -import threading import time import traceback -from typing import Union +import ujson from flask import request +from redis import Redis -from ..cache import redis +from . import bp from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler -from ..queue import decr_active_workers, decrement_ip_count, priority_queue +from ..queue import priority_queue from ... import opts -from ...database.database import log_prompt -from ...llm.generator import generator -from ...llm.vllm import tokenize -from ...stream import sock +from ...custom_redis import redis +from ...database.log_to_db import log_to_db +from ...sock import sock -# TODO: have workers process streaming requests -# TODO: make sure to log the token as well (seems to be missing in the DB right now) +# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint. +# We solve this by splitting the routes -@sock.route('/api/v1/stream') -def stream(ws): - def send_err_and_quit(quitting_err_msg): - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': 0, - 'text': quitting_err_msg - })) - ws.send(json.dumps({ - 'event': 'stream_end', - 'message_num': 1 - })) - ws.close() - log_in_bg(quitting_err_msg, is_error=True) +@bp.route('/v1/stream') +@bp.route('//v1/stream') +def stream(model_name=None): + return 'This is a websocket endpoint.', 400 - def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None): - def background_task_exception(): - generated_tokens = tokenize(generated_text_bg) - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error) +@sock.route('/v1/stream', bp=bp) +def stream_without_model(ws): + do_stream(ws, model_name=None) - # TODO: use async/await instead of threads - thread = threading.Thread(target=background_task_exception) - thread.start() - thread.join() - if not opts.enable_streaming: - return 'Streaming is disabled', 401 +@sock.route('//v1/stream', bp=bp) +def stream_with_model(ws, model_name=None): + do_stream(ws, model_name) - r_headers = dict(request.headers) - r_url = request.url - message_num = 0 - while ws.connected: - message = ws.receive() - request_valid_json, request_json_body = validate_json(message) - if not request_valid_json or not request_json_body.get('prompt'): - return 'Invalid JSON', 400 - else: - if opts.mode != 'vllm': - # TODO: implement other backends - raise NotImplementedError - auth_failure = require_api_key(request_json_body) - if auth_failure: - return auth_failure +def do_stream(ws, model_name): + event_id = None + try: + def send_err_and_quit(quitting_err_msg): + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': 0, + 'text': quitting_err_msg + })) + ws.send(json.dumps({ + 'event': 'stream_end', + 'message_num': 1 + })) + ws.close() + log_to_db(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=quitting_err_msg, + gen_time=None, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url, + response_tokens=None, + is_error=True + ) - handler = OobaRequestHandler(request, request_json_body) - generated_text = '' - input_prompt = request_json_body['prompt'] - response_status_code = 0 - start_time = time.time() + if not opts.enable_streaming: + return 'Streaming disabled', 403 - err_msg = None - if handler.is_client_ratelimited(): - r, _ = handler.handle_ratelimited(do_log=False) - err_msg = r.json['results'][0]['text'] + r_headers = dict(request.headers) + r_url = request.url + message_num = 0 + + while ws.connected: + message = ws.receive() + request_valid_json, request_json_body = validate_json(message) + + if not request_valid_json or not request_json_body.get('prompt'): + return 'Invalid JSON', 400 else: - request_valid, invalid_response = handler.validate_request(prompt=input_prompt) - if not request_valid: - err_msg = invalid_response[0].json['results'][0]['text'] - if err_msg: - send_err_and_quit(err_msg) - return + # We have to do auth ourselves since the details are sent in the message. + auth_failure = require_api_key(request_json_body) + if auth_failure: + return auth_failure - llm_request = { - **handler.parameters, - 'prompt': input_prompt, - 'stream': True, - } - - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority) - if not event: - r, _ = handler.handle_ratelimited() - err_msg = r.json['results'][0]['text'] - send_err_and_quit(err_msg) - return - try: - response = generator(llm_request) - if not response: - error_msg = 'Failed to reach backend while streaming.' - print('Streaming failed:', error_msg) - msg = handler.handle_error(error_msg)[0].json['results'][0]['text'] + handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body) + if handler.offline: + msg = f'{handler.selected_model} is not a valid model choice.' + print(msg) ws.send(json.dumps({ 'event': 'text_stream', - 'message_num': message_num, + 'message_num': 0, 'text': msg })) + return + + if handler.cluster_backend_info['mode'] != 'vllm': + # TODO: implement other backends + raise NotImplementedError + + input_prompt = request_json_body['prompt'] + response_status_code = 0 + start_time = time.time() + + err_msg = None + if handler.is_client_ratelimited(): + r, _ = handler.handle_ratelimited(do_log=False) + err_msg = r.json['results'][0]['text'] else: - # Be extra careful when getting attributes from the response object - try: - response_status_code = response.status_code - except: - response_status_code = 0 + request_valid, invalid_response = handler.validate_request(prompt=input_prompt) + if not request_valid: + err_msg = invalid_response[0].json['results'][0]['text'] + if err_msg: + send_err_and_quit(err_msg) + return - partial_response = b'' + handler.parameters, _ = handler.get_parameters() + handler.prompt = input_prompt + handler.request_json_body = { + 'prompt': handler.prompt, + **handler.parameters + } - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(input_prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue - try: - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': new - })) - except: - # The has client closed the stream. - if request: - request.close() - ws.close() - end_time = time.time() - elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) - return + event = None + if not handler.is_client_ratelimited(): + event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) + if not event: + r = handler.handle_ratelimited() + send_err_and_quit(r[0].data) + return + event_id = event.event_id + _, stream_name, error_msg = event.wait() + if error_msg: + print('Stream failed to start streaming:', error_msg) + ws.close(reason=1014, message='Request Timeout') + return + + stream_redis = Redis(db=8) + generated_text = '' + + try: + last_id = '0-0' # The ID of the last entry we read. + while True: + stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout) + if not stream_data: + print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.") + return + else: + for stream_index, item in stream_data[0][1]: + last_id = stream_index + data = ujson.loads(item[b'data']) + if data['error']: + print(data['error']) + send_err_and_quit('Encountered exception while streaming.') + return + elif data['new']: + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': data['new'] + })) message_num += 1 - partial_response = b'' # Reset the partial response - - # If there is no more data, break the loop - if not chunk: - break - - end_time = time.time() - elapsed_time = end_time - start_time - log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code) - except: - traceback.print_exc() - generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': generated_text - })) - if request: - request.close() - ws.close() - log_in_bg(generated_text, is_error=True, status_code=response_status_code) - return - finally: - # The worker incremented it, we'll decrement it. - decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers() - try: - ws.send(json.dumps({ - 'event': 'stream_end', - 'message_num': message_num - })) - except: - # The client closed the stream. - end_time = time.time() - elapsed_time = end_time - start_time - log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) - ws.close() # this is important if we encountered and error and exited early. + generated_text = generated_text + data['new'] + elif data['completed']: + return + except: + send_err_and_quit('Encountered exception while streaming.') + traceback.print_exc() + finally: + try: + ws.send(json.dumps({ + 'event': 'stream_end', + 'message_num': message_num + })) + except: + # The client closed the stream. + pass + if stream_name: + stream_redis.delete(stream_name) + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db(ip=handler.client_ip, + token=handler.token, + prompt=input_prompt, + response=generated_text, + gen_time=elapsed_time, + parameters=handler.parameters, + headers=r_headers, + backend_response_code=response_status_code, + request_url=r_url, + backend_url=handler.backend_url + ) + finally: + if event_id: + redis.publish(f'notifications:{event_id}', 'canceled') + try: + # Must close the connection or greenlets will complain. + ws.close() + except: + pass diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py index 2091118..342921e 100644 --- a/llm_server/routes/v1/info.py +++ b/llm_server/routes/v1/info.py @@ -2,22 +2,16 @@ import time from flask import jsonify, request +from llm_server.custom_redis import flask_cache from . import bp -from ..auth import requires_auth -from ..cache import flask_cache from ... import opts -from ...llm.info import get_running_model +from ...cluster.backend import get_backends_from_model, is_valid_model +from ...cluster.cluster_config import cluster_config, get_a_cluster_backend -# @bp.route('/info', methods=['GET']) -# # @cache.cached(timeout=3600, query_string=True) -# def get_info(): -# # requests.get() -# return 'yes' - - -@bp.route('/model', methods=['GET']) -def get_model(): +@bp.route('/v1/model', methods=['GET']) +@bp.route('//v1/model', methods=['GET']) +def get_model(model_name=None): # We will manage caching ourself since we don't want to cache # when the backend is down. Also, Cloudflare won't cache 500 errors. cache_key = 'model_cache::' + request.url @@ -26,24 +20,21 @@ def get_model(): if cached_response: return cached_response - model_name, error = get_running_model() if not model_name: + model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model') + + if not is_valid_model(model_name): response = jsonify({ - 'code': 502, - 'msg': 'failed to reach backend', - 'type': error.__class__.__name__ - }), 500 # return 500 so Cloudflare doesn't intercept us + 'code': 400, + 'msg': 'Model does not exist.', + }), 400 else: + num_backends = len(get_backends_from_model(model_name)) response = jsonify({ 'result': opts.manual_model_name if opts.manual_model_name else model_name, + 'model_backend_count': num_backends, 'timestamp': int(time.time()) }), 200 flask_cache.set(cache_key, response, timeout=60) return response - - -@bp.route('/backend', methods=['GET']) -@requires_auth -def get_backend(): - return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200 diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py index 4349526..6e3708e 100644 --- a/llm_server/routes/v1/proxy.py +++ b/llm_server/routes/v1/proxy.py @@ -1,8 +1,10 @@ from flask import jsonify +from llm_server.custom_redis import flask_cache from . import bp from .generate_stats import generate_stats -from ..cache import flask_cache +from ..auth import requires_auth +from ...cluster.cluster_config import cluster_config, get_backends from ...helpers import jsonify_pretty @@ -10,3 +12,14 @@ from ...helpers import jsonify_pretty @flask_cache.cached(timeout=5, query_string=True) def get_stats(): return jsonify_pretty(generate_stats()) + + +@bp.route('/backends', methods=['GET']) +@requires_auth +def get_backend(): + online, offline = get_backends() + result = {} + for i in online + offline: + info = cluster_config.get_backend(i) + result[info['hash']] = info + return jsonify(result), 200 diff --git a/llm_server/stream.py b/llm_server/sock.py similarity index 77% rename from llm_server/stream.py rename to llm_server/sock.py index 8ac2fc1..2f1a17d 100644 --- a/llm_server/stream.py +++ b/llm_server/sock.py @@ -3,6 +3,6 @@ from flask_sock import Sock sock = Sock() -def init_socketio(app): +def init_wssocket(app): global sock sock.init_app(app) diff --git a/llm_server/workers/app.py b/llm_server/workers/app.py deleted file mode 100644 index fda6fb3..0000000 --- a/llm_server/workers/app.py +++ /dev/null @@ -1,35 +0,0 @@ -from threading import Thread - -from .blocking import start_workers -from .main import main_background_thread -from .moderator import start_moderation_workers -from .printer import console_printer -from .recent import recent_prompters_thread -from .threads import cache_stats -from .. import opts - - -def start_background(): - start_workers(opts.concurrent_gens) - - t = Thread(target=main_background_thread) - t.daemon = True - t.start() - print('Started the main background thread.') - - start_moderation_workers(opts.openai_moderation_workers) - - t = Thread(target=cache_stats) - t.daemon = True - t.start() - print('Started the stats cacher.') - - t = Thread(target=recent_prompters_thread) - t.daemon = True - t.start() - print('Started the recent proompters thread.') - - t = Thread(target=console_printer) - t.daemon = True - t.start() - print('Started the console printer.') diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py deleted file mode 100644 index 27b0815..0000000 --- a/llm_server/workers/blocking.py +++ /dev/null @@ -1,51 +0,0 @@ -import threading -import time - -from llm_server import opts -from llm_server.llm.generator import generator -from llm_server.routes.cache import redis -from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue - - -def worker(): - while True: - need_to_wait() - (request_json_body, client_ip, token, parameters), event_id = priority_queue.get() - need_to_wait() - - increment_ip_count(client_ip, 'processing_ips') - incr_active_workers() - - if not request_json_body: - # This was a dummy request from the websocket handler. - # We're going to let the websocket handler decrement processing_ips and active_gen_workers. - continue - - try: - success, response, error_msg = generator(request_json_body) - event = DataEvent(event_id) - event.set((success, response, error_msg)) - finally: - decrement_ip_count(client_ip, 'processing_ips') - decr_active_workers() - - -def start_workers(num_workers: int): - i = 0 - for _ in range(num_workers): - t = threading.Thread(target=worker) - t.daemon = True - t.start() - i += 1 - print(f'Started {i} inference workers.') - - -def need_to_wait(): - # We need to check the number of active workers since the streaming endpoint may be doing something. - active_workers = redis.get('active_gen_workers', int, 0) - s = time.time() - while active_workers >= opts.concurrent_gens: - time.sleep(0.01) - e = time.time() - if e - s > 0.5: - print(f'Worker was delayed {e - s} seconds.') diff --git a/llm_server/workers/cleaner.py b/llm_server/workers/cleaner.py new file mode 100644 index 0000000..95a6a78 --- /dev/null +++ b/llm_server/workers/cleaner.py @@ -0,0 +1,32 @@ +import time + +from redis import Redis + +from llm_server.workers.inferencer import STREAM_NAME_PREFIX + + +# NOT NEEDED + +def cleaner(): + r = Redis(db=8) + stream_info = {} + + while True: + all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*') + processed_streams = [] + for stream in all_streams: + stream = stream.decode() + current_size = r.xlen(stream) + + # If the stream is new or its size has changed, update the size and time in the dictionary + if stream not in stream_info or current_size != stream_info[stream]['size']: + stream_info[stream] = {'size': current_size, 'time': time.time()} + processed_streams.append(stream) + else: + # If the size hasn't changed for 5 minutes, delete the stream + if time.time() - stream_info[stream]['time'] >= 300: + r.delete(stream) + print(f"Stream '{stream}' deleted due to inactivity.") + del stream_info[stream] + + time.sleep(60) diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py new file mode 100644 index 0000000..21e45d0 --- /dev/null +++ b/llm_server/workers/inferencer.py @@ -0,0 +1,160 @@ +import json +import threading +import time +import traceback +from uuid import uuid4 + +import ujson +from redis import Redis + +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import RedisCustom, redis +from llm_server.llm.generator import generator +from llm_server.logging import create_logger +from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count + +stream_redis = Redis(db=8) + +STREAM_NAME_PREFIX = 'stream' + + +def check_cancellation(event, event_id): + """ + This thread checks the pub/sub channel in the background so the main process + isn't bogged down with Redis calls. Otherwise, the main process slows down to 1 token/sec. + :param event: + :param event_id: + :return: + """ + pubsub = redis.pubsub() + pubsub.subscribe(f'notifications:{event_id}') + while not event.is_set(): + message = pubsub.get_message() + if message and message['data'] == b'canceled': + event.set() + time.sleep(0.5) # check every half second + + +def get_stream_name(name: str): + return f'{STREAM_NAME_PREFIX}:{name}' + + +def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str): + logger = create_logger('inferencer') + prompt = msg_to_backend['prompt'] + stream_name = get_stream_name(stream_name) + stream_redis.delete(get_stream_name(stream_name)) # be extra sure + event = threading.Event() + threading.Thread(target=check_cancellation, args=(event, event_id)).start() + try: + response = generator(msg_to_backend, backend_url) + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + # If there is no more data, break the loop + if not chunk: + break + if event.is_set(): + logger.debug('Client canceled generation') + response.close() + return + + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(prompt + generated_text)[1] + generated_text = generated_text + new + except IndexError: + # ???? + continue + stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})}) + except AttributeError as e: + if str(e) == "'bool' object has no attribute 'iter_content'": + # We don't care about these errors. + logger.debug('failed to stream from backend - no response') + else: + raise + except Exception as e: + stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})}) + raise # We won't handle the exception here. + finally: + # Publish final message to Redis stream + stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})}) + event.set() # stop the cancellation checking thread + + +# +def worker(backend_url): + logger = create_logger('inferencer') + status_redis = RedisCustom('worker_status') + worker_id = str(uuid4()) + status_redis.setp(str(worker_id), None) + redis_queue = RedisPriorityQueue(backend_url) + while True: + status_redis.setp(str(worker_id), 'waiting...') + (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get() + event = DataEvent(event_id) + + try: + backend_info = cluster_config.get_backend(backend_url) + except: + # This is not a critical error because it usually means that the backend is + # offline and this backend is in a state of transition from online to offline. + logger.debug(f'got an exception while getting info for backend {backend_url} - ', traceback.format_exc()) + event.set((False, None, 'exception')) + continue + + if not backend_info['online']: + event.set((False, None, 'canceled')) + continue + + if not selected_model: + selected_model = backend_info['model'] + + logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}. Streaming: {do_stream}") + + try: + stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams + increment_ip_count(client_ip, 'processing_ips') + incr_active_workers(selected_model, backend_url) + + if do_stream: + status_redis.setp(str(worker_id), ('streaming', client_ip)) + + # Return the name of the stream that the slave should connect to. + event.set((True, get_stream_name(worker_id), None)) + + msg_to_backend = { + **parameters, + 'prompt': request_json_body['prompt'], + 'stream': True, + } + inference_do_stream(worker_id, msg_to_backend, backend_url, event_id) + else: + # Normal inference (not streaming). + status_redis.setp(str(worker_id), ('generating', client_ip)) + success, response, error_msg = generator(request_json_body, backend_url) + event.set((success, response, error_msg)) + except: + logger.error(traceback.format_exc()) + event.set((False, None, 'exception')) + finally: + decrement_ip_count(client_ip, 'processing_ips') + decr_active_workers(selected_model, backend_url) + status_redis.setp(str(worker_id), None) + + +def start_workers(cluster: dict): + logger = create_logger('inferencer') + i = 0 + for item in cluster: + for _ in range(item['concurrent_gens']): + t = threading.Thread(target=worker, args=(item['backend_url'],)) + t.daemon = True + t.start() + i += 1 + logger.info(f'Started {i} inference workers.') diff --git a/llm_server/workers/logger.py b/llm_server/workers/logger.py new file mode 100644 index 0000000..eada969 --- /dev/null +++ b/llm_server/workers/logger.py @@ -0,0 +1,31 @@ +import pickle +import traceback + +import redis + +from llm_server.database.database import do_db_log + + +def db_logger(): + """ + We don't want the logging operation to be blocking, so we will use a background worker + to do the logging. + :return: + """ + + r = redis.Redis(host='localhost', port=6379, db=3) + p = r.pubsub() + p.subscribe('database-logger') + + for message in p.listen(): + try: + if message['type'] == 'message': + data = pickle.loads(message['data']) + function_name = data['function'] + args = data['args'] + kwargs = data['kwargs'] + + if function_name == 'log_prompt': + do_db_log(*args, **kwargs) + except: + traceback.print_exc() diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py deleted file mode 100644 index 747f699..0000000 --- a/llm_server/workers/main.py +++ /dev/null @@ -1,56 +0,0 @@ -import time -from threading import Thread - -from llm_server import opts -from llm_server.database.database import weighted_average_column_for_model -from llm_server.llm.info import get_running_model -from llm_server.routes.cache import redis - - -def main_background_thread(): - redis.set('average_generation_elapsed_sec', 0) - redis.set('estimated_avg_tps', 0) - redis.set('average_output_tokens', 0) - redis.set('backend_online', 0) - redis.set_dict('backend_info', {}) - - while True: - # TODO: unify this - if opts.mode == 'oobabooga': - running_model, err = get_running_model() - if err: - print(err) - redis.set('backend_online', 0) - else: - redis.set('running_model', running_model) - redis.set('backend_online', 1) - elif opts.mode == 'vllm': - running_model, err = get_running_model() - if err: - print(err) - redis.set('backend_online', 0) - else: - redis.set('running_model', running_model) - redis.set('backend_online', 1) - else: - raise Exception - - # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 - # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. - average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0 - if average_generation_elapsed_sec: # returns None on exception - redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec) - - # overall = average_column_for_model('prompts', 'generation_time', opts.running_model) - # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}') - - average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0 - if average_generation_elapsed_sec: - redis.set('average_output_tokens', average_output_tokens) - - # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model) - # print(f'Weighted: {average_output_tokens}, overall: {overall}') - - estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero - redis.set('estimated_avg_tps', estimated_avg_tps) - time.sleep(60) diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py new file mode 100644 index 0000000..d342f4b --- /dev/null +++ b/llm_server/workers/mainer.py @@ -0,0 +1,57 @@ +import time + +import requests + +from llm_server import opts +from llm_server.cluster.cluster_config import cluster_config, get_backends +from llm_server.custom_redis import redis +from llm_server.database.database import weighted_average_column_for_model +from llm_server.llm.info import get_info +from llm_server.routes.queue import RedisPriorityQueue, priority_queue + + +def main_background_thread(): + while True: + online, offline = get_backends() + for backend_url in online: + backend_info = cluster_config.get_backend(backend_url) + backend_mode = backend_info['mode'] + backend_info = get_info(backend_url, backend_mode) + running_model = backend_info.get('model') + if not running_model: + continue + + average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode) + if average_generation_elapsed_sec: # returns None on exception + cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec) + if average_output_tokens: + cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens) + if average_generation_elapsed_sec and average_output_tokens: + cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps) + + if opts.background_homepage_cacher: + try: + base_client_api = redis.get('base_client_api', dtype=str) + r = requests.get('https://' + base_client_api, timeout=5) + except Exception as e: + print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}') + + backends = priority_queue.get_backends() + for backend_url in backends: + queue = RedisPriorityQueue(backend_url) + queue.cleanup() + + time.sleep(30) + + +def calc_stats_for_backend(backend_url, running_model, backend_mode): + # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0 + # was entered into the column. The new code enters null instead but we need to be backwards compatible for now. + average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', + running_model, backend_mode, backend_url, exclude_zeros=True, + include_system_tokens=opts.include_system_tokens_in_stats) or 0 + average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', + running_model, backend_mode, backend_url, exclude_zeros=True, + include_system_tokens=opts.include_system_tokens_in_stats) or 0 + estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero + return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps diff --git a/llm_server/workers/moderator.py b/llm_server/workers/moderator.py index 4457d05..6d56eee 100644 --- a/llm_server/workers/moderator.py +++ b/llm_server/workers/moderator.py @@ -1,10 +1,13 @@ import json import threading +import time import traceback import redis as redis_redis +from llm_server import opts from llm_server.llm.openai.moderation import check_moderation_endpoint +from llm_server.logging import create_logger redis_moderation = redis_redis.Redis() @@ -16,36 +19,43 @@ def start_moderation_workers(num_workers): t.daemon = True t.start() i += 1 - print(f'Started {i} moderation workers.') -def moderation_worker(): - while True: - result = redis_moderation.blpop('queue:msgs_to_check') - try: - msg, tag = json.loads(result[1]) - _, categories = check_moderation_endpoint(msg) - redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories))) - except: - print(result) - traceback.print_exc() - continue - - -def add_moderation_task(msg, tag): - redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag)))) - +# TODO: don't use UUID tags to identify items. Use native redis def get_results(tag, num_tasks): - tag = str(tag) # Required for comparison with Redis results. + tag = str(tag) # Cast a UUID4 to a string. flagged_categories = set() num_results = 0 + start_time = time.time() while num_results < num_tasks: - result = redis_moderation.blpop('queue:flagged_categories') + result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout) + if result is None: + break # Timeout occurred, break the loop. result_tag, categories = json.loads(result[1]) if result_tag == tag: if categories: for item in categories: flagged_categories.add(item) num_results += 1 + if time.time() - start_time > opts.openai_moderation_timeout: + logger.warning('Timed out waiting for result from moderator') + break return list(flagged_categories) + + +def moderation_worker(): + logger = create_logger('moderator') + while True: + result = redis_moderation.blpop(['queue:msgs_to_check']) + try: + msg, tag = json.loads(result[1]) + _, categories = check_moderation_endpoint(msg) + redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories))) + except: + logger.error(traceback.format_exc()) + continue + + +def add_moderation_task(msg, tag): + redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag)))) diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py index cb0f032..deb3246 100644 --- a/llm_server/workers/printer.py +++ b/llm_server/workers/printer.py @@ -1,25 +1,34 @@ -import logging import time +import traceback -from llm_server.routes.cache import redis +from llm_server.cluster.backend import get_running_models +from llm_server.cluster.cluster_config import cluster_config +from llm_server.custom_redis import redis +from llm_server.logging import create_logger from llm_server.routes.queue import priority_queue -logger = logging.getLogger('console_printer') -if not logger.handlers: - handler = logging.StreamHandler() - handler.setLevel(logging.INFO) - logger.setLevel(logging.INFO) - formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s") - handler.setFormatter(formatter) - logger.addHandler(handler) - def console_printer(): + logger = create_logger('console_printer') time.sleep(3) while True: - processing = redis.hkeys('processing_ips') - processing_count = 0 - for ip in processing: - processing_count += int(redis.hget('processing_ips', ip)) - logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}') + try: + processing = redis.keys('active_gen_workers:http*') # backends always start with http + processing_count = 0 + if len(processing): + for k in processing: + processing_count += redis.get(k, default=0, dtype=int) + backends = [k for k, v in cluster_config.all().items() if v['online']] + activity = priority_queue.activity() + + # Calculate the queue size the same way it's done on the stats. + queue_size = 0 + running_models = get_running_models() + for model in running_models: + queue_size += priority_queue.len(model) + + # Active Workers and Processing should read the same. If not, that's an issue. + logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}') + except: + logger.error(traceback.format_exc()) time.sleep(10) diff --git a/llm_server/workers/recent.py b/llm_server/workers/recenter.py similarity index 81% rename from llm_server/workers/recent.py rename to llm_server/workers/recenter.py index d378a87..c6158d6 100644 --- a/llm_server/workers/recent.py +++ b/llm_server/workers/recenter.py @@ -1,6 +1,6 @@ import time -from llm_server.routes.cache import redis +from llm_server.custom_redis import redis def recent_prompters_thread(): diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py new file mode 100644 index 0000000..542a630 --- /dev/null +++ b/llm_server/workers/threader.py @@ -0,0 +1,58 @@ +import time +from threading import Thread + +from llm_server import opts +from llm_server.cluster.worker import cluster_worker +from llm_server.logging import create_logger +from llm_server.routes.v1.generate_stats import generate_stats +from llm_server.workers.inferencer import start_workers +from llm_server.workers.logger import db_logger +from llm_server.workers.mainer import main_background_thread +from llm_server.workers.moderator import start_moderation_workers +from llm_server.workers.printer import console_printer +from llm_server.workers.recenter import recent_prompters_thread + + +def cache_stats(): + while True: + generate_stats(regen=True) + time.sleep(5) + + +def start_background(): + logger = create_logger('threader') + start_workers(opts.cluster) + + t = Thread(target=main_background_thread) + t.daemon = True + t.start() + logger.info('Started the main background thread.') + + num_moderators = opts.cluster_workers * 3 + start_moderation_workers(num_moderators) + logger.info(f'Started {num_moderators} moderation workers.') + + t = Thread(target=cache_stats) + t.daemon = True + t.start() + logger.info('Started the stats cacher.') + + t = Thread(target=recent_prompters_thread) + t.daemon = True + t.start() + logger.info('Started the recent proompters thread.') + + t = Thread(target=console_printer) + t.daemon = True + t.start() + logger.info('Started the console logger.infoer.') + + t = Thread(target=cluster_worker) + t.daemon = True + t.start() + logger.info('Started the cluster worker.') + + t = Thread(target=db_logger) + t.daemon = True + t.start() + logger.info('Started background logger.') diff --git a/llm_server/workers/threads.py b/llm_server/workers/threads.py deleted file mode 100644 index d1c5183..0000000 --- a/llm_server/workers/threads.py +++ /dev/null @@ -1,9 +0,0 @@ -import time - -from llm_server.routes.v1.generate_stats import generate_stats - - -def cache_stats(): - while True: - generate_stats(regen=True) - time.sleep(5) diff --git a/other/gradio/gradio_chat.py b/other/gradio/gradio_chat.py new file mode 100644 index 0000000..179e748 --- /dev/null +++ b/other/gradio/gradio_chat.py @@ -0,0 +1,103 @@ +import os +import sys +import time +import traceback +import warnings +from threading import Thread + +import gradio as gr +import openai +import requests + +warnings.filterwarnings("ignore") + +API_BASE = os.getenv('API_BASE') +if not API_BASE: + print('Must set the secret variable API_BASE to your https://your-site/api') + sys.exit(1) +API_BASE = API_BASE.strip('/') + +APP_TITLE = os.getenv('APP_TITLE') +PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE') +TRACKING_CODE = os.getenv('TRACKING_CODE') + + +def background(): + while True: + previous = openai.api_base + try: + r = requests.get(API_BASE + '/stats').json() + if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys(): + openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1' + else: + openai.api_base = API_BASE + '/openai/v1' + except: + traceback.print_exc() + openai.api_base = API_BASE + '/openai/v1' + if openai.api_base != previous: + print('Set primary model to', openai.api_base) + time.sleep(10) + + +if PRIMARY_MODEL_CHOICE: + t = Thread(target=background) + t.daemon = True + t.start() + print('Started the background thread.') + +# A system prompt can be injected into the very first spot in the context. +# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE, +# the content in CONTEXT_TRIGGER_INJECTION will be injected. +# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples. +CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE') +CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION') + +openai.api_key = 'null' +openai.api_base = API_BASE + '/openai/v1' + + +def stream_response(prompt, history): + messages = [] + do_injection = False + for human, assistant in history: + messages.append({'role': 'user', 'content': str(human)}) + messages.append({'role': 'assistant', 'content': str(assistant)}) + + if CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in human: + do_injection = True + messages.append({'role': 'user', 'content': prompt}) + + if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt): + messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION}) + + try: + response = openai.ChatCompletion.create( + model='0', + messages=messages, + temperature=0, + max_tokens=300, + stream=True, + headers={'LLM-Source': 'huggingface-demo'} + ) + except Exception: + raise gr.Error("Failed to reach inference endpoint.") + + message = '' + for chunk in response: + if len(chunk['choices'][0]['delta']) != 0: + message += chunk['choices'][0]['delta']['content'] + yield message + + +examples = ["hello"] +if CONTEXT_TRIGGER_PHRASE: + examples.insert(0, CONTEXT_TRIGGER_PHRASE) + +with gr.Blocks(analytics_enabled=False) as demo: + gr.ChatInterface(stream_response, examples=examples, title=APP_TITLE, analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}') + + if TRACKING_CODE: + print('Inserting tracking code') + gr.HTML(TRACKING_CODE) + +demo.queue(concurrency_count=1, api_open=False).launch(show_api=False) diff --git a/other/gradio/requirements.txt b/other/gradio/requirements.txt new file mode 100644 index 0000000..eb4baac --- /dev/null +++ b/other/gradio/requirements.txt @@ -0,0 +1,3 @@ +gradio +openai +requests \ No newline at end of file diff --git a/gunicorn.py b/other/gunicorn.py similarity index 60% rename from gunicorn.py rename to other/gunicorn.py index 30f9274..099e9ce 100644 --- a/gunicorn.py +++ b/other/gunicorn.py @@ -1,3 +1,8 @@ +""" +This file is used to run certain tasks when the HTTP server starts. +It's located here so it doesn't get imported with daemon.py +""" + try: import gevent.monkey diff --git a/other/nginx-site.conf b/other/nginx-site.conf new file mode 100644 index 0000000..1c81d3d --- /dev/null +++ b/other/nginx-site.conf @@ -0,0 +1,38 @@ +server +{ + listen 443 ssl http2 default_server; + server_name _; + + proxy_set_header Host $host; + proxy_set_header Connection $http_connection; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Scheme $scheme; + + location ~* ^/api/(.*?|v1|openai)/(v1|(generate|stream)|(chat/completions|completions))$ + { + # Route to inference endpoints + proxy_pass http://127.0.0.1:5000; + + # Required for streaming (both websockets and SSE). + proxy_buffering off; + proxy_cache off; + proxy_http_version 1.1; + proxy_set_header Upgrade $http_upgrade; + proxy_set_header Connection "upgrade"; + + # Set long timeouts for inference operations. + # Cloudflare has a timeout of 100 seconds. + proxy_read_timeout 120; + proxy_connect_timeout 120; + proxy_send_timeout 120; + } + + location / + { + proxy_pass http://127.0.0.1:5000; + } + + ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt; + ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key; + include /etc/nginx/snippets/ssl-params.conf; +} diff --git a/other/tests/config.sh b/other/tests/config.sh new file mode 100644 index 0000000..64bea46 --- /dev/null +++ b/other/tests/config.sh @@ -0,0 +1,11 @@ +HOST="proxy.chub-archive.evulid.cc" + +AUTH_KEY="TEST_1df979f0-6df1-41bd-814a-e99b1680e727" + +PROXY_SERVERS=( + "http://172.0.4.7:3128" + "http://172.0.4.8:3128" + "http://172.0.4.10:3128" + "http://172.0.4.12:3128" + "http://172.0.4.13:3128" +) diff --git a/other/tests/generate.sh b/other/tests/generate.sh new file mode 100755 index 0000000..b1443c0 --- /dev/null +++ b/other/tests/generate.sh @@ -0,0 +1,58 @@ +#!/bin/bash + +SLEEP_TIME=2 + +while getopts p:t: flag; do + case "${flag}" in + p) PROXY_CHOICE=${OPTARG} ;; + t) SLEEP_TIME=${OPTARG} ;; + *) ;; + esac +done + +SOURCE=${BASH_SOURCE[0]} +while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + SOURCE=$(readlink "$SOURCE") + [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + +source "$DIR/config.sh" + +if [ -n "$PROXY_CHOICE" ]; then + our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}" + echo "Using $our_proxy_server" +else + our_proxy_server="" +fi + +while true; do + echo "--> START <--" + + DATA=$( + cat < DONE <--\n" + sleep $SLEEP_TIME +done diff --git a/other/tests/oai-chat-completion.sh b/other/tests/oai-chat-completion.sh new file mode 100755 index 0000000..5355a8a --- /dev/null +++ b/other/tests/oai-chat-completion.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +DO_STREAM=false +SLEEP_TIME=2 + +while getopts p:t:s flag; do + case "${flag}" in + s) DO_STREAM=true ;; + p) PROXY_CHOICE=${OPTARG} ;; + t) SLEEP_TIME=${OPTARG} ;; + *) ;; + esac +done + +SOURCE=${BASH_SOURCE[0]} +while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + SOURCE=$(readlink "$SOURCE") + [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + +source "$DIR/config.sh" + +if [ ! -z "$PROXY_CHOICE" ]; then + our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}" + echo "Using $our_proxy_server" +else + our_proxy_server="" +fi + +while true; do + echo "--> START <--" + + DATA=$( + cat < DONE <--\n" + sleep $SLEEP_TIME +done diff --git a/other/tests/oai-completion.sh b/other/tests/oai-completion.sh new file mode 100755 index 0000000..cc0f9f0 --- /dev/null +++ b/other/tests/oai-completion.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +DO_STREAM=false +SLEEP_TIME=2 + +while getopts p:t:s flag; do + case "${flag}" in + s) DO_STREAM=true ;; + p) PROXY_CHOICE=${OPTARG} ;; + t) SLEEP_TIME=${OPTARG} ;; + *) ;; + esac +done + +SOURCE=${BASH_SOURCE[0]} +while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + SOURCE=$(readlink "$SOURCE") + [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd) + +source "$DIR/config.sh" + +if [ ! -z "$PROXY_CHOICE" ]; then + our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}" + echo "Using $our_proxy_server" +else + our_proxy_server="" +fi + +while true; do + echo "--> START <--" + + DATA=$( + cat < DONE <--\n" + sleep $SLEEP_TIME +done diff --git a/other/tests/start-bulk.sh b/other/tests/start-bulk.sh new file mode 100755 index 0000000..6f254d5 --- /dev/null +++ b/other/tests/start-bulk.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# Function to display help message +function display_help { + echo "Usage: $0 -n num_windows -c command" + echo + echo " -n, --number Number of windows to create" + echo " -c, --command Command to run in each window" + echo + exit 1 +} + +# Parse command line arguments +while getopts "n:c:h" opt; do + case ${opt} in + n) + num_windows=${OPTARG} + ;; + c) + command=${OPTARG} + ;; + h) + display_help + ;; + \?) + echo "Invalid option: -$OPTARG" 1>&2 + display_help + ;; + :) + echo "Option -$OPTARG requires an argument." 1>&2 + display_help + ;; + esac +done + +# Check if number of windows and command are provided +if [ -z "$num_windows" ] || [ -z "$command" ]; then + echo "Both number of windows and command are required." + display_help +fi + +# Calculate rows and columns +rows=$(echo "sqrt($num_windows)" | bc) +columns=$(echo "($num_windows + $rows - 1) / $rows" | bc) + +# Create a new tmux session +tmux new-session -d -s llm_tester "$command -p 0" + +# Create the remaining windows +for ((i = 1; i < $num_windows; i++)); do + if ((i % $columns == 0)); then + tmux select-layout -t llm_tester:0 tiled + tmux select-pane -t 0 + tmux split-window -t llm_tester:0 -v "$command -p $i" + else + tmux split-window -t llm_tester:0 -h "$command -p $i" + fi +done + +# Balance the windows +tmux select-layout -t llm_tester:0 tiled + +# Attach to the session +tmux attach-session -t llm_tester diff --git a/other/ooba-test-streaming.py b/other/tests/stream.py old mode 100644 new mode 100755 similarity index 52% rename from other/ooba-test-streaming.py rename to other/tests/stream.py index 883c2f5..75d403b --- a/other/ooba-test-streaming.py +++ b/other/tests/stream.py @@ -1,37 +1,50 @@ import asyncio import json import sys +import os +import time +from pathlib import Path try: import websockets except ImportError: print("Websockets package not found. Make sure it's installed.") -# For local streaming, the websockets are hosted without ssl - ws:// -HOST = 'localhost:5000' -URI = f'ws://{HOST}/api/v1/stream' +script_path = os.path.dirname(os.path.realpath(__file__)) -# For reverse-proxied streaming, the remote will likely host with ssl - wss:// -# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' + +def parse_bash_config(file_path): + config = {} + with open(file_path, 'r') as f: + for line in f: + if line.startswith('#') or '=' not in line: + continue + key, value = line.strip().split('=', 1) + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + elif value.startswith('(') and value.endswith(')'): + value = value[1:-1].split() + value = [v.strip('"') for v in value] + config[key] = value + return config + + +config = parse_bash_config(Path(script_path, 'config.sh')) async def run(context): - # Note: the selected defaults change from time to time. request = { 'prompt': context, 'max_new_tokens': 250, 'auto_max_new_tokens': False, 'max_tokens_second': 0, - - # Generation params. If 'preset' is set to different than 'None', the values - # in presets/preset-name.yaml are used instead of the individual numbers. 'preset': 'None', 'do_sample': True, 'temperature': 0.7, 'top_p': 0.1, 'typical_p': 1, - 'epsilon_cutoff': 0, # In units of 1e-4 - 'eta_cutoff': 0, # In units of 1e-4 + 'epsilon_cutoff': 0, + 'eta_cutoff': 0, 'tfs': 1, 'top_a': 0, 'repetition_penalty': 1.18, @@ -48,7 +61,6 @@ async def run(context): 'mirostat_eta': 0.1, 'guidance_scale': 1, 'negative_prompt': '', - 'seed': -1, 'add_bos_token': True, 'truncation_length': 2048, @@ -58,7 +70,7 @@ async def run(context): 'stopping_strings': [] } - async with websockets.connect(URI, ping_interval=None) as websocket: + async with websockets.connect(f'wss://{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket: await websocket.send(json.dumps(request)) yield context # Remove this if you just want to see the reply @@ -67,20 +79,28 @@ async def run(context): incoming_data = await websocket.recv() incoming_data = json.loads(incoming_data) + print(incoming_data) + match incoming_data['event']: - case 'text_stream': - yield incoming_data['text'] + # case 'text_stream': + # yield incoming_data['text'] case 'stream_end': return async def print_response_stream(prompt): - async for response in run(prompt): - print(response, end='') - sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. - print('\n\nfinished') + try: + async for response in run(prompt): + print(response, end='') + sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. + except Exception as e: + print(e) if __name__ == '__main__': - prompt = "In order to make homemade bread, follow these steps:\n1)" - asyncio.run(print_response_stream(prompt)) + prompt = "Write a 300 word story about an apple tree.\n\n" + while True: + print('--> START <--') + asyncio.run(print_response_stream(prompt)) + print('--> DONE <--') + time.sleep(2) diff --git a/requirements.txt b/requirements.txt index 9b0c8eb..89f4be7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,21 +1,18 @@ flask~=2.3.3 -flask_cors pyyaml~=6.0.1 -flask_caching +Flask-Caching==2.0.2 requests~=2.31.0 tiktoken~=0.5.0 -gunicorn gevent~=23.9.0.post1 -async-timeout -flask-sock -uvicorn~=0.23.2 -fastapi~=0.103.1 -torch~=2.0.1 PyMySQL~=1.1.0 -DBUtils~=3.0.3 simplejson~=3.19.1 websockets~=11.0.3 basicauth~=1.0.0 openai~=0.28.0 -urllib3~=2.0.4 -celery[redis] +flask-sock==0.6.0 +gunicorn==21.2.0 +redis==5.0.1 +ujson==5.8.0 +vllm==0.2.1.post1 +gradio~=3.46.1 +coloredlogs~=15.0.1 \ No newline at end of file diff --git a/server.py b/server.py index 06482d4..aa8ef1a 100644 --- a/server.py +++ b/server.py @@ -1,5 +1,3 @@ -from llm_server.config.config import mode_ui_names - try: import gevent.monkey @@ -7,37 +5,46 @@ try: except ImportError: pass -from llm_server.pre_fork import server_startup -from llm_server.config.load import load_config import os import sys from pathlib import Path import simplejson as json -from flask import Flask, jsonify, render_template, request +from flask import Flask, jsonify, render_template, request, Response -import llm_server +import config +from llm_server import opts +from llm_server.cluster.backend import get_model_choices +from llm_server.cluster.cluster_config import cluster_config +from llm_server.config.config import mode_ui_names +from llm_server.config.load import load_config +from llm_server.custom_redis import flask_cache, redis from llm_server.database.conn import database from llm_server.database.create import create_db -from llm_server.llm import get_token_count -from llm_server.routes.openai import openai_bp +from llm_server.helpers import auto_set_base_client_api +from llm_server.llm.vllm.info import vllm_info +from llm_server.pre_fork import server_startup +from llm_server.routes.openai import openai_bp, openai_model_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp -from llm_server.stream import init_socketio +from llm_server.routes.v1.generate_stats import generate_stats +from llm_server.sock import init_wssocket -# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. -# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail -# TODO: implement background thread to test backends via sending test prompts -# TODO: if backend fails request, mark it as down -# TODO: allow setting concurrent gens per-backend -# TODO: set the max tokens to that of the lowest backend -# TODO: implement RRD backend loadbalancer option -# TODO: have VLLM reject a request if it already has n == concurrent_gens running -# TODO: add a way to cancel VLLM gens. Maybe use websockets? -# TODO: use coloredlogs -# TODO: need to update opts. for workers +# TODO: seperate queue item timeout for websockets (make longer, like 5 minutes) +# TODO: return an `error: True`, error code, and error message rather than just a formatted message +# TODO: what happens when all backends are offline? What about the "online" key in the stats page? +# TODO: redis SCAN vs KEYS?? +# TODO: is frequency penalty the same as ooba repetition penalty??? +# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions # Lower priority +# TODO: if a backend is at its limit of concurrent requests, choose a different one +# TODO: make error messages consitient +# TODO: support logit_bias on OpenAI and Ooba endpoints. +# TODO: add a way to cancel VLLM gens. Maybe use websockets? +# TODO: validate openai_silent_trim works as expected and only when enabled +# TODO: rewrite config storage. Store in redis so we can reload it. +# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens # TODO: the estiamted wait time lags behind the stats # TODO: simulate OpenAI error messages regardless of endpoint @@ -59,19 +66,16 @@ except ModuleNotFoundError as e: print('Please see README.md for install instructions.') sys.exit(1) -import config -from llm_server import opts -from llm_server.helpers import auto_set_base_client_api -from llm_server.llm.vllm.info import vllm_info -from llm_server.routes.cache import RedisWrapper, flask_cache -from llm_server.llm import redis -from llm_server.routes.stats import get_active_gen_workers -from llm_server.routes.v1.generate_stats import generate_stats - app = Flask(__name__) -init_socketio(app) -app.register_blueprint(bp, url_prefix='/api/v1/') + +# Fixes ConcurrentObjectUseError +# https://github.com/miguelgrinberg/simple-websocket/issues/24 +app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25} + +app.register_blueprint(bp, url_prefix='/api/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') +app.register_blueprint(openai_model_bp, url_prefix='/api/openai/') +init_wssocket(app) flask_cache.init_app(app) flask_cache.clear() @@ -82,18 +86,13 @@ if config_path_environ: else: config_path = Path(script_path, 'config', 'config.yml') -success, config, msg = load_config(config_path, script_path) +success, config, msg = load_config(config_path) if not success: print('Failed to load config:', msg) sys.exit(1) database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) create_db() -llm_server.llm.redis = RedisWrapper('local_llm') -create_db() - - -# print(app.url_map) @app.route('/') @@ -101,20 +100,30 @@ create_db() @app.route('/api/openai') @flask_cache.cached(timeout=10) def home(): + base_client_api = redis.get('base_client_api', dtype=str) stats = generate_stats() + model_choices, default_model = get_model_choices() - if not stats['online']: - running_model = estimated_wait_sec = 'offline' - else: - running_model = redis.get('running_model', str, 'ERROR') + if default_model: + if not model_choices.get(default_model): + return 'The server is still starting up. Please wait...' - active_gen_workers = get_active_gen_workers() - if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens: + default_model_info = model_choices[default_model] + + if default_model_info['queued'] == 0 and default_model_info['queued'] >= default_model_info['concurrent_gens']: # There will be a wait if the queue is empty but prompts are processing, but we don't # know how long. - estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds" + default_estimated_wait_sec = f"less than {int(default_model_info['estimated_wait'])} seconds" else: - estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds" + default_estimated_wait_sec = f"{int(default_model_info['estimated_wait'])} seconds" + else: + default_model_info = { + 'model': 'OFFLINE', + 'processing': '-', + 'queued': '-', + 'context_size': '-', + } + default_estimated_wait_sec = 'OFFLINE' if len(config['analytics_tracking_code']): analytics_tracking_code = f"" @@ -127,32 +136,47 @@ def home(): info_html = '' mode_info = '' - if opts.mode == 'vllm': - mode_info = vllm_info - - base_client_api = redis.get('base_client_api', str) + for k, v in cluster_config.all().items(): + if v['mode'] == 'vllm': + mode_info = vllm_info + break return render_template('home.html', llm_middleware_name=opts.llm_middleware_name, analytics_tracking_code=analytics_tracking_code, info_html=info_html, - current_model=opts.manual_model_name if opts.manual_model_name else running_model, + default_model=default_model_info['model'], + default_active_gen_workers=default_model_info['processing'], + default_proompters_in_queue=default_model_info['queued'], + current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model, client_api=f'https://{base_client_api}', - ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, - estimated_wait=estimated_wait_sec, - mode_name=mode_ui_names[opts.mode][0], - api_input_textbox=mode_ui_names[opts.mode][1], - streaming_input_textbox=mode_ui_names[opts.mode][2], - context_size=opts.context_size, + ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled', + default_estimated_wait=default_estimated_wait_sec, + mode_name=mode_ui_names[opts.frontend_api_mode][0], + api_input_textbox=mode_ui_names[opts.frontend_api_mode][1], + streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2], + default_context_size=default_model_info['context_size'], stats_json=json.dumps(stats, indent=4, ensure_ascii=False), extra_info=mode_info, openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled', expose_openai_system_prompt=opts.expose_openai_system_prompt, enable_streaming=opts.enable_streaming, + model_choices=model_choices, + proompters_5_min=stats['stats']['proompters']['5_min'], + proompters_24_hrs=stats['stats']['proompters']['24_hrs'], ) -# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend +@app.route('/robots.txt') +def robots(): + # TODO: have config value to deny all + # TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt + t = """User-agent: * +Allow: /""" + r = Response(t) + r.headers['Content-Type'] = 'text/plain' + return r + @app.route('/') @app.route('//') diff --git a/templates/home.html b/templates/home.html index 4b9c153..66340a6 100644 --- a/templates/home.html +++ b/templates/home.html @@ -65,6 +65,19 @@ .hidden { display: none; } + + .header-workers { + font-weight: normal; + font-size: 14pt; + } + + h3 { + font-size: 16pt; + } + + .no-marker { + list-style: none; + } @@ -76,8 +89,12 @@

{{ llm_middleware_name }}

-

Current Model: {{ current_model }}

-

Estimated Wait Time: {{ estimated_wait }}

+

Current Model: {{ default_model }}

+

+ Estimated Wait Time: {{ default_estimated_wait }}
+ Processing: {{ default_active_gen_workers }}
+ Queued: {{ default_proompters_in_queue }} +


Client API URL: {{ client_api }}

Streaming API URL: {{ ws_client_api if enable_streaming else 'Disabled' }}

@@ -91,17 +108,20 @@
-
- Instructions: +

Instructions

+
    +
  1. In Settings > Power User Options, enable Relaxed API URLS.
  2. Set your API type to {{ mode_name }}
  3. Enter {{ client_api }} in the {{ api_input_textbox }} textbox.
  4. - {% if enable_streaming %}
  5. Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
  6. {% endif %} + {% if enable_streaming %} +
  7. Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
  8. + {% endif %}
  9. If you have a token, check the Mancer AI checkbox and enter your token in the Mancer API key textbox.
  10. Click Connect to test the connection.
  11. -
  12. Open your preset config and set Context Size to {{ context_size }}.
  13. +
  14. Open your preset config and set Context Size to {{ default_context_size }}.
  15. Follow this guide to get set up: rentry.org/freellamas
@@ -120,13 +140,45 @@
-
{{ stats_json|safe }}
+

Statistics

+ Proompters: +
    +
  • 5 minutes: {{ proompters_5_min }}
  • +
  • 24 hours: {{ proompters_24_hrs }}
  • +
+
+ + {% for key, value in model_choices.items() %} +
+

{{ key }} - {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}

+ + {% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %} + {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #} + {% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %} + {% else %} + {% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %} + {% endif %} + +

+ Estimated Wait Time: {{ estimated_wait_sec }}
+ Processing: {{ value.processing }}
+ Queued: {{ value.queued }}
+

+

+ Client API URL: {{ value.client_api }}
+ Streaming API URL: {{ value.ws_client_api }}
+ OpenAI-Compatible API URL: {{ value.openai_client_api }} +

+

Context Size: {{ value.context_size }}

+

Average Generation Time: {{ value.avg_generation_time | int }} seconds

+
+
+ {% endfor %}
-