diff --git a/README.md b/README.md
index 429e246..ccfaaf4 100644
--- a/README.md
+++ b/README.md
@@ -43,7 +43,9 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
+If you see unexpected errors in the console, make sure `daemon.py` is running; otherwise the data the server needs will be missing from Redis. The daemon may also take a few minutes to populate the database after it starts.
+Flask may throw unusual errors when run directly with `python server.py`. This seems to come from Flask-Socket. Running under Gunicorn avoids the issue: `gunicorn -b :5000 --worker-class gevent server:app`
### To Do
diff --git a/daemon.py b/daemon.py
index 20ec300..69e8532 100644
--- a/daemon.py
+++ b/daemon.py
@@ -1,22 +1,19 @@
-import time
-
-from llm_server.routes.cache import redis
-
-try:
- import gevent.monkey
-
- gevent.monkey.patch_all()
-except ImportError:
- pass
-
+import argparse
+import logging
import os
import sys
+import time
from pathlib import Path
-from llm_server.config.load import load_config
-from llm_server.database.create import create_db
+from redis import Redis
-from llm_server.workers.app import start_background
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.config.load import load_config, parse_backends
+from llm_server.custom_redis import redis
+from llm_server.database.create import create_db
+from llm_server.logging import create_logger, logging_info, init_logging
+from llm_server.routes.v1.generate_stats import generate_stats
+from llm_server.workers.threader import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@@ -26,19 +23,46 @@ else:
config_path = Path(script_path, 'config', 'config.yml')
if __name__ == "__main__":
- flushed_keys = redis.flush()
- print('Flushed', len(flushed_keys), 'keys from Redis.')
+ parser = argparse.ArgumentParser(description='Daemon microservice.')
+ parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
+ parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
+ args = parser.parse_args()
- success, config, msg = load_config(config_path, script_path)
+ # TODO: have this be set by either the arg or a config value
+ if args.debug:
+ logging_info.level = logging.DEBUG
+
+ init_logging()
+ logger = create_logger('daemon')
+ logger.debug('Debug logging enabled.')
+
+ if not args.no_reset:
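+        # Flush every Redis database so state from a previous run doesn't carry over.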
+ Redis().flushall()
+ logger.info('Flushed Redis.')
+
+ success, config, msg = load_config(config_path)
if not success:
- print('Failed to load config:', msg)
+        logger.error(f'Failed to load config: {msg}')
sys.exit(1)
create_db()
+
+ cluster_config.clear()
+ cluster_config.load(parse_backends(config))
+
+ logger.info('Loading backend stats...')
+ generate_stats(regen=True)
+
start_background()
- redis.set('daemon_started', 1)
- print('== Daemon Setup Complete ==\n')
+ # Give some time for the background threads to get themselves ready to go.
+ time.sleep(2)
- while True:
- time.sleep(3600)
+ redis.set('daemon_started', 1)
+ logger.info('== Daemon Setup Complete ==')
+
+ try:
+ while True:
+ time.sleep(3600)
+ except KeyboardInterrupt:
+ redis.set('daemon_started', 0)
diff --git a/llm_server/cluster/__init__.py b/llm_server/cluster/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/llm_server/cluster/backend.py b/llm_server/cluster/backend.py
new file mode 100644
index 0000000..9e2e19b
--- /dev/null
+++ b/llm_server/cluster/backend.py
@@ -0,0 +1,117 @@
+import numpy as np
+
+from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
+from llm_server.cluster.stores import redis_running_models
+from llm_server.custom_redis import redis
+from llm_server.llm.generator import generator
+from llm_server.llm.info import get_info
+from llm_server.llm.vllm.vllm_backend import VLLMBackend
+from llm_server.routes.queue import priority_queue
+from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
+
+
+def get_backends_from_model(model_name: str):
+ return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
+
+
+def get_running_models():
+ return redis_running_models.keys()
+
+
+def purge_backend_from_running_models(backend_url: str):
+ keys = redis_running_models.keys()
+ pipeline = redis_running_models.pipeline()
+ for model in keys:
+ pipeline.srem(model, backend_url)
+ pipeline.execute()
+
+
+def is_valid_model(model_name: str):
+ return redis_running_models.exists(model_name)
+
+
+def test_backend(backend_url: str, test_prompt: bool = False):
+ backend_info = cluster_config.get_backend(backend_url)
+ if test_prompt:
+ handler = VLLMBackend(backend_url)
+ parameters, _ = handler.get_parameters({
+ "stream": False,
+ "temperature": 0,
+ "max_new_tokens": 3,
+ })
+ data = {
+ 'prompt': 'test prompt',
+ **parameters
+ }
+ try:
+ success, response, err = generator(data, backend_url, timeout=10)
+ if not success or not response or err:
+ return False, {}
+        except Exception:
+ return False, {}
+ i = get_info(backend_url, backend_info['mode'])
+ if not i.get('model'):
+ return False, {}
+ return True, i
+
+
+def get_model_choices(regen: bool = False):
+ if not regen:
+ c = redis.getp('model_choices')
+ if c:
+ return c
+
+ base_client_api = redis.get('base_client_api', dtype=str)
+ running_models = get_running_models()
+ model_choices = {}
+ for model in running_models:
+ b = get_backends_from_model(model)
+
+ context_size = []
+ avg_gen_per_worker = []
+ concurrent_gens = 0
+ for backend_url in b:
+ backend_info = cluster_config.get_backend(backend_url)
+ if backend_info.get('model_config'):
+ context_size.append(backend_info['model_config']['max_position_embeddings'])
+ if backend_info.get('average_generation_elapsed_sec'):
+ avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
+ concurrent_gens += backend_info['concurrent_gens']
+
+ active_gen_workers = get_active_gen_workers_model(model)
+ proompters_in_queue = priority_queue.len(model)
+
+ if len(avg_gen_per_worker):
+ average_generation_elapsed_sec = np.average(avg_gen_per_worker)
+ else:
+ average_generation_elapsed_sec = 0
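+        # Estimate the wait from the average generation time, queue length, total worker slots, and busy workers.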
+ estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, concurrent_gens, active_gen_workers)
+
+ model_choices[model] = {
+ 'model': model,
+ 'client_api': f'https://{base_client_api}/{model}',
+ 'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
+ 'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
+ 'backend_count': len(b),
+ 'estimated_wait': estimated_wait_sec,
+ 'queued': proompters_in_queue,
+ 'processing': active_gen_workers,
+ 'avg_generation_time': average_generation_elapsed_sec,
+ 'concurrent_gens': concurrent_gens
+ }
+
+ if len(context_size):
+ model_choices[model]['context_size'] = min(context_size)
+
+    # Sort case-insensitively; otherwise Python orders uppercase and lowercase model names separately.
+ model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
+
+ default_backend_url = get_a_cluster_backend()
+ default_backend_info = cluster_config.get_backend(default_backend_url)
+ if not default_backend_info.get('model'):
+ return {}, None
+ default_model = default_backend_info['model']
+
+ redis.setp('model_choices', (model_choices, default_model))
+ return model_choices, default_model
diff --git a/llm_server/cluster/cluster_config.py b/llm_server/cluster/cluster_config.py
new file mode 100644
index 0000000..891dfc1
--- /dev/null
+++ b/llm_server/cluster/cluster_config.py
@@ -0,0 +1,124 @@
+import hashlib
+import pickle
+import traceback
+
+from llm_server import opts
+from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
+from llm_server.cluster.stores import redis_running_models
+from llm_server.custom_redis import RedisCustom
+from llm_server.routes.helpers.model import estimate_model_size
+
+
+class RedisClusterStore:
+ def __init__(self, name: str, **kwargs):
+ self.name = name
+ self.config_redis = RedisCustom(name, **kwargs)
+
+ def clear(self):
+ self.config_redis.flush()
+
+ def load(self, config: dict):
+ for k, v in config.items():
+ self.add_backend(k, v)
+
+ def add_backend(self, name: str, values: dict):
+ self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
+ self.set_backend_value(name, 'online', False)
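+        # Derive a short, stable identifier for this backend from the SHA-256 hash of its URL.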
+ h = hashlib.sha256(name.encode('utf-8')).hexdigest()
+ self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
+
+ def set_backend_value(self, backend: str, key: str, value):
+ # By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
+ self.config_redis.hset(backend, key, pickle.dumps(value))
+
+ def get_backend(self, name: str):
+ r = self.config_redis.hgetall(name)
+ output = {}
+ for k, v in r.items():
+ output[k.decode('utf8')] = pickle.loads(v)
+ return output
+
+ def all(self):
+ keys = self.config_redis.keys('*')
+ if keys:
+ result = {}
+ for key in keys:
+ if key != f'{self.name}:____':
+ v = self.get_backend(key)
+ result[key] = v
+ return result
+ else:
+ return {}
+
+ def validate_backend(self, backend_url: str):
+ """
+ Returns the backend URL that was given, or a new one if that was offline.
+ :param backend_url:
+ :return:
+ """
+ backend_info = self.get_backend(backend_url)
+ if not backend_info['online']:
+ old = backend_url
+ backend_url = get_a_cluster_backend()
+ print(f'Backend {old} offline. Request was redirected to {backend_url}')
+ return backend_url
+
+
+cluster_config = RedisClusterStore('cluster_config')
+
+
+def get_backends():
+ backends = cluster_config.all()
+ result = {}
+ for k, v in backends.items():
+ b = cluster_config.get_backend(k)
+ status = b.get('online', False)
+ priority = b['priority']
+ result[k] = {'status': status, 'priority': priority}
+
+ try:
+ if not opts.prioritize_by_size:
+ online_backends = sorted(
+ ((url, info) for url, info in backends.items() if info['online']),
+ key=lambda kv: -kv[1]['priority'],
+ reverse=True
+ )
+ else:
+ online_backends = sorted(
+ ((url, info) for url, info in backends.items() if info['online']),
+ key=lambda kv: estimate_model_size(kv[1]['model_config']),
+ reverse=True
+ )
+ offline_backends = sorted(
+ ((url, info) for url, info in backends.items() if not info['online']),
+ key=lambda kv: -kv[1]['priority'],
+ reverse=True
+ )
+ return [url for url, info in online_backends], [url for url, info in offline_backends]
+ except KeyError:
+ traceback.print_exc()
+ print(backends)
+
+
+def get_a_cluster_backend(model=None):
+ """
+ Get a backend from Redis. If there are no online backends, return None.
+    If `model` is not supplied, we will pick one ourselves.
+ """
+ if model:
+ # First, determine if there are multiple backends hosting the same model.
+ backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]
+
+ # If so, create an iterator for those backends
+ if len(backends_hosting_model):
+ add_backend_cycler(model, backends_hosting_model)
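+            # redis_cycle() rotates the list, so repeated calls round-robin across the backends hosting this model.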
+ cycled = redis_cycle(model)
+ if len(cycled):
+ return cycled[0]
+ else:
+ # No backend hosting that model
+ return None
+ else:
+ online, _ = get_backends()
+ if len(online):
+ return online[0]
diff --git a/llm_server/cluster/redis_cycle.py b/llm_server/cluster/redis_cycle.py
new file mode 100644
index 0000000..266241d
--- /dev/null
+++ b/llm_server/cluster/redis_cycle.py
@@ -0,0 +1,39 @@
+import redis
+
+redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)
+
+
+def redis_cycle(list_name):
+ """
+    Emulates itertools.cycle(), but rotates the list by one element and returns the entire rotated list.
+ :param list_name:
+ :return:
+ """
+ pipeline = redis_cycler_db.pipeline()
+ pipeline.lpop(list_name)
+ to_move = pipeline.execute()[0]
+ if not to_move:
+ return []
+ pipeline.rpush(list_name, to_move)
+ pipeline.lrange(list_name, 0, -1)
+ results = pipeline.execute()
+ new_list = results[-1]
+ return [x.decode('utf-8') for x in new_list]
+
+
+def add_backend_cycler(list_name: str, new_elements: list):
+ existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
+ existing_set = set(existing_elements)
+
+ with redis_cycler_db.pipeline() as pipe:
+ # Add elements
+ for element in new_elements:
+ if element not in existing_set:
+ pipe.rpush(list_name, element)
+
+ # Remove elements
+ for element in existing_set:
+ if element not in new_elements:
+ pipe.lrem(list_name, 0, element)
+
+ pipe.execute()
diff --git a/llm_server/cluster/stores.py b/llm_server/cluster/stores.py
new file mode 100644
index 0000000..c0cbdcc
--- /dev/null
+++ b/llm_server/cluster/stores.py
@@ -0,0 +1,3 @@
+from llm_server.custom_redis import RedisCustom
+
+redis_running_models = RedisCustom('running_models')
diff --git a/llm_server/cluster/worker.py b/llm_server/cluster/worker.py
new file mode 100644
index 0000000..9652db9
--- /dev/null
+++ b/llm_server/cluster/worker.py
@@ -0,0 +1,38 @@
+import time
+from threading import Thread
+
+from llm_server.cluster.backend import test_backend
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.cluster.stores import redis_running_models
+
+
+def cluster_worker():
+ counter = 0
+ while True:
+ test_prompt = False
+ if counter % 4 == 0:
+            # Only send a test prompt every fourth cycle (roughly every 60 seconds).
+ test_prompt = True
+ threads = []
+ for n, v in cluster_config.all().items():
+ thread = Thread(target=check_backend, args=(n, v, test_prompt))
+ thread.start()
+ threads.append(thread)
+ for thread in threads:
+ thread.join()
+ time.sleep(15)
+ counter += 1
+
+
+def check_backend(n, v, test_prompt):
+ online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
+ if online:
+ running_model = backend_info['model']
+ for k, v in backend_info.items():
+ cluster_config.set_backend_value(n, k, v)
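+        # Record which model this backend is serving so requests for that model can be routed to it.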
+ redis_running_models.sadd(running_model, n)
+ else:
+ for model in redis_running_models.keys():
+ redis_running_models.srem(model, n)
+
+ cluster_config.set_backend_value(n, 'online', online)
diff --git a/llm_server/config/config.py b/llm_server/config/config.py
index 59568d7..2c08544 100644
--- a/llm_server/config/config.py
+++ b/llm_server/config/config.py
@@ -28,16 +28,19 @@ config_default_vars = {
'openai_force_no_hashes': True,
'include_system_tokens_in_stats': True,
'openai_moderation_scan_last_n': 5,
- 'openai_moderation_workers': 10,
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True,
- 'netdata_root': None
+ 'netdata_root': None,
+ 'show_backends': True,
+ 'background_homepage_cacher': True,
+ 'openai_moderation_timeout': 5,
+ 'prioritize_by_size': False
}
-config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
+config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
mode_ui_names = {
- 'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
+ 'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
}
diff --git a/llm_server/config/load.py b/llm_server/config/load.py
index 64469b2..cc3250c 100644
--- a/llm_server/config/load.py
+++ b/llm_server/config/load.py
@@ -3,38 +3,28 @@ import sys
import openai
+import llm_server
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
-from llm_server.helpers import resolve_path
-from llm_server.routes.cache import redis
+from llm_server.routes.queue import PriorityQueue
-def load_config(config_path, script_path):
+def load_config(config_path):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
return success, config, msg
- # Resolve relative directory to the directory of the script
- if config['database_path'].startswith('./'):
- config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
- if config['mode'] not in ['oobabooga', 'vllm']:
- print('Unknown mode:', config['mode'])
- sys.exit(1)
-
# TODO: this is atrocious
- opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
- opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
- opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
- opts.backend_url = config['backend_url'].strip('/')
+ opts.cluster = config['cluster']
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@@ -53,10 +43,20 @@ def load_config(config_path, script_path):
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
- opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
+ opts.show_backends = config['show_backends']
+ opts.background_homepage_cacher = config['background_homepage_cacher']
+ opts.openai_moderation_timeout = config['openai_moderation_timeout']
+ opts.frontend_api_mode = config['frontend_api_mode']
+ opts.prioritize_by_size = config['prioritize_by_size']
+
+ # Scale the number of workers.
+ for item in config['cluster']:
+ opts.cluster_workers += item['concurrent_gens']
+
+ llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
@@ -78,6 +78,16 @@ def load_config(config_path, script_path):
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
- redis.set('backend_mode', opts.mode)
-
return success, config, msg
+
+
+def parse_backends(config):
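+    # Flatten the list of cluster entries into a dict keyed by the normalized backend URL.
+    # Illustrative shape of one cluster entry (keys as used elsewhere in this repo):
+    #   {'backend_url': 'http://127.0.0.1:7000', 'mode': 'vllm', 'concurrent_gens': 3, 'priority': 100}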
+ if not config.get('cluster'):
+ return False
+ cluster = config.get('cluster')
+ config = {}
+ for item in cluster:
+ backend_url = item['backend_url'].strip('/')
+ item['backend_url'] = backend_url
+ config[backend_url] = item
+ return config
diff --git a/llm_server/config/redis_config.py b/llm_server/config/redis_config.py
new file mode 100644
index 0000000..06ab1d3
--- /dev/null
+++ b/llm_server/config/redis_config.py
@@ -0,0 +1,3 @@
+from llm_server.custom_redis import RedisCustom
+
+redis_config = RedisCustom('redis_config')
diff --git a/llm_server/routes/cache.py b/llm_server/custom_redis.py
similarity index 52%
rename from llm_server/routes/cache.py
rename to llm_server/custom_redis.py
index d7046db..a055537 100644
--- a/llm_server/routes/cache.py
+++ b/llm_server/custom_redis.py
@@ -1,24 +1,27 @@
+import pickle
import sys
import traceback
-from typing import Callable, List, Mapping, Union
+from typing import Callable, List, Mapping, Optional, Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
-from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
+from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
-flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
+flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
ONE_MONTH_SECONDS = 2678000
-class RedisWrapper:
+class RedisCustom(Redis):
"""
- A wrapper class to set prefixes to keys.
+ A simple wrapper class for Redis to create a "namespace" within a DB,
+    which simplifies key management.
"""
def __init__(self, prefix, **kwargs):
+ super().__init__()
self.redis = Redis(**kwargs)
self.prefix = prefix
try:
@@ -34,12 +37,11 @@ class RedisWrapper:
def set(self, key, value, ex: Union[ExpiryT, None] = None):
return self.redis.set(self._key(key), value, ex=ex)
- def get(self, key, dtype=None, default=None):
- """
- :param key:
- :param dtype: convert to this type
- :return:
- """
+ def get(self, key, default=None, dtype=None):
+        # TODO: use pickle
+        import inspect
+        if inspect.isclass(default):
+            # Guard against callers that pass a type as `default` when they meant `dtype`.
+            raise Exception('default must be a value, not a class (did you mean to pass dtype?)')
d = self.redis.get(self._key(key))
if dtype and d:
@@ -108,7 +110,10 @@ class RedisWrapper:
):
return self.redis.hincrby(self._key(name), key, amount)
- def hdel(self, name: str, *keys: List):
+ def zcard(self, name: KeyT):
+ return self.redis.zcard(self._key(name))
+
+ def hdel(self, name: str, *keys: str):
return self.redis.hdel(self._key(name), *keys)
def hget(
@@ -129,9 +134,62 @@ class RedisWrapper:
):
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
+ def lpush(self, name: str, *values: FieldT):
+ return self.redis.lpush(self._key(name), *values)
+
+ def hset(
+ self,
+ name: str,
+ key: Optional = None,
+ value=None,
+ mapping: Optional[dict] = None,
+ items: Optional[list] = None,
+ ):
+ return self.redis.hset(self._key(name), key, value, mapping, items)
+
def hkeys(self, name: str):
return self.redis.hkeys(self._key(name))
+ def hmget(self, name: str, keys: List, *args: List):
+ return self.redis.hmget(self._key(name), keys, *args)
+
+ def hgetall(self, name: str):
+ return self.redis.hgetall(self._key(name))
+
+ def keys(self, pattern: PatternT = "*", **kwargs):
+ raw_keys = self.redis.keys(self._key(pattern), **kwargs)
+ keys = []
+ for key in raw_keys:
+ p = key.decode('utf-8').split(':')
+ if len(p) >= 2:
+ # Delete prefix
+ del p[0]
+ k = ':'.join(p)
+ if k != '____':
+ keys.append(k)
+ return keys
+
+ def pipeline(self, transaction=True, shard_hint=None):
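+        # Note: this returns a raw pipeline, so keys passed to it are not automatically prefixed.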
+ return self.redis.pipeline(transaction, shard_hint)
+
+ def smembers(self, name: str):
+ return self.redis.smembers(self._key(name))
+
+ def spop(self, name: str, count: Optional[int] = None):
+ return self.redis.spop(self._key(name), count)
+
+ def rpoplpush(self, src, dst):
+ return self.redis.rpoplpush(src, dst)
+
+ def zpopmin(self, name: KeyT, count: Union[int, None] = None):
+ return self.redis.zpopmin(self._key(name), count)
+
+ def exists(self, *names: KeyT):
+ n = []
+ for name in names:
+ n.append(self._key(name))
+ return self.redis.exists(*n)
+
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex)
@@ -142,6 +200,15 @@ class RedisWrapper:
else:
return json.loads(r.decode("utf-8"))
+ def setp(self, name, value):
+ self.redis.set(self._key(name), pickle.dumps(value))
+
+ def getp(self, name: str):
+ r = self.redis.get(self._key(name))
+ if r:
+ return pickle.loads(r)
+ return r
+
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):
@@ -149,5 +216,40 @@ class RedisWrapper:
self.redis.delete(key)
return flushed
+ def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
+ self.flush()
+ return True
-redis = RedisWrapper('local_llm')
+ def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
+ self.flush()
+ return True
+
+ def lrange(self, name: str, start: int, end: int):
+ return self.redis.lrange(self._key(name), start, end)
+
+ def delete(self, *names: KeyT):
+ return self.redis.delete(*[self._key(i) for i in names])
+
+ def lpop(self, name: str, count: Optional[int] = None):
+ return self.redis.lpop(self._key(name), count)
+
+ def zrange(
+ self,
+ name: KeyT,
+ start: int,
+ end: int,
+ desc: bool = False,
+ withscores: bool = False,
+ score_cast_func: Union[type, Callable] = float,
+ byscore: bool = False,
+ bylex: bool = False,
+ offset: int = None,
+ num: int = None,
+ ):
+ return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
+
+ def zrem(self, name: KeyT, *values: FieldT):
+ return self.redis.zrem(self._key(name), *values)
+
+
+redis = RedisCustom('local_llm')
diff --git a/llm_server/database/conn.py b/llm_server/database/conn.py
index 25f3326..f63f555 100644
--- a/llm_server/database/conn.py
+++ b/llm_server/database/conn.py
@@ -5,20 +5,20 @@ class DatabaseConnection:
host: str = None
username: str = None
password: str = None
- database: str = None
+ database_name: str = None
- def init_db(self, host, username, password, database):
+ def init_db(self, host, username, password, database_name):
self.host = host
self.username = username
self.password = password
- self.database = database
+ self.database_name = database_name
def cursor(self):
db = pymysql.connect(
host=self.host,
user=self.username,
password=self.password,
- database=self.database,
+ database=self.database_name,
charset='utf8mb4',
autocommit=True,
)
diff --git a/llm_server/database/database.py b/llm_server/database/database.py
index 9bfe578..d6bd6b2 100644
--- a/llm_server/database/database.py
+++ b/llm_server/database/database.py
@@ -1,15 +1,19 @@
import json
import time
import traceback
+from typing import Union
-import llm_server
from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
-from llm_server.llm.vllm import tokenize
-from llm_server.routes.cache import redis
+from llm_server.llm import get_token_count
-def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
+def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+ assert isinstance(prompt, str)
+ assert isinstance(backend_url, str)
+
+ # Try not to shove JSON into the database.
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
try:
@@ -19,10 +23,11 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
except:
pass
- prompt_tokens = llm_server.llm.get_token_count(prompt)
+ prompt_tokens = get_token_count(prompt, backend_url)
+
if not is_error:
if not response_tokens:
- response_tokens = llm_server.llm.get_token_count(response)
+ response_tokens = get_token_count(response, backend_url)
else:
response_tokens = None
@@ -43,7 +48,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if token:
increment_token_uses(token)
- running_model = redis.get('running_model', str, 'ERROR')
+ backend_info = cluster_config.get_backend(backend_url)
+ running_model = backend_info.get('model')
+ backend_mode = backend_info['mode']
timestamp = int(time.time())
cursor = database.cursor()
try:
@@ -52,7 +59,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
- (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+ (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
@@ -179,3 +186,21 @@ def increment_token_uses(token):
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
finally:
cursor.close()
+
+
+def get_token_ratelimit(token):
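+    # Defaults for requests without a token; a matching row in token_auth overrides them below.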
+ priority = 9990
+ simultaneous_ip = opts.simultaneous_requests_per_ip
+ if token:
+ cursor = database.cursor()
+ try:
+ cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
+ result = cursor.fetchone()
+ if result:
+ priority, simultaneous_ip = result
+ if simultaneous_ip is None:
+ # No ratelimit for this token if null
+ simultaneous_ip = 999999999
+ finally:
+ cursor.close()
+ return priority, simultaneous_ip
diff --git a/llm_server/database/log_to_db.py b/llm_server/database/log_to_db.py
new file mode 100644
index 0000000..75bcaab
--- /dev/null
+++ b/llm_server/database/log_to_db.py
@@ -0,0 +1,30 @@
+import pickle
+from typing import Union
+
+from redis import Redis
+
+
+def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+ assert isinstance(prompt, str)
+ assert isinstance(backend_url, str)
+
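+    # Hand the entry to a Redis pub/sub channel so the actual database write happens off the request path.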
+ r = Redis(host='localhost', port=6379, db=3)
+ data = {
+ 'function': 'log_prompt',
+ 'args': [],
+ 'kwargs': {
+ 'ip': ip,
+ 'token': token,
+ 'prompt': prompt,
+ 'response': response,
+ 'gen_time': gen_time,
+ 'parameters': parameters,
+ 'headers': dict(headers) if headers else headers,
+ 'backend_response_code': backend_response_code,
+ 'request_url': request_url,
+ 'backend_url': backend_url,
+ 'response_tokens': response_tokens,
+ 'is_error': is_error
+ }
+ }
+ r.publish('database-logger', pickle.dumps(data))
diff --git a/llm_server/helpers.py b/llm_server/helpers.py
index 44b436b..91f3b15 100644
--- a/llm_server/helpers.py
+++ b/llm_server/helpers.py
@@ -8,7 +8,7 @@ import simplejson as json
from flask import make_response
from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
def resolve_path(*p: str):
@@ -54,13 +54,14 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys
def round_up_base(n, base):
if base == 0:
- print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
+        # Passing (0, 0) to this function doesn't appear to indicate an underlying issue, so the noisy warning below is disabled.
+ # print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
return 0
return math.ceil(n / base) * base
def auto_set_base_client_api(request):
- http_host = redis.get('http_host', str)
+ http_host = redis.get('http_host', dtype=str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
# If the current http_host is not an IP, don't do anything.
diff --git a/llm_server/integer.py b/llm_server/integer.py
deleted file mode 100644
index 1410dd1..0000000
--- a/llm_server/integer.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import threading
-
-
-class ThreadSafeInteger:
- def __init__(self, value=0):
- self.value = value
- self._value_lock = threading.Lock()
-
- def increment(self):
- with self._value_lock:
- self.value += 1
- return self.value
diff --git a/llm_server/llm/__init__.py b/llm_server/llm/__init__.py
index 742b1a5..ac6702b 100644
--- a/llm_server/llm/__init__.py
+++ b/llm_server/llm/__init__.py
@@ -1,11 +1,30 @@
+import tiktoken
+
+from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import oobabooga, vllm
-from llm_server.routes.cache import redis
+from llm_server.logging import create_logger
-def get_token_count(prompt: str):
- backend_mode = redis.get('backend_mode', str)
+def fallback_tokenizer(prompt: str):
+ tokenizer = tiktoken.get_encoding("cl100k_base")
+ return len(tokenizer.encode(prompt)) + 10
+
+
+def get_token_count(prompt: str, backend_url: str):
+ backend_url = cluster_config.validate_backend(backend_url)
+ if not backend_url:
+ logger = create_logger('tokenizer')
+ logger.warning('using fallback tokenizer as there is no valid backend')
+ return fallback_tokenizer(prompt)
+
+ backend_mode = cluster_config.get_backend(backend_url).get('mode')
+ if not backend_mode:
+ logger = create_logger('tokenizer')
+ logger.warning("using fallback tokenizer as the backend isn't initalized")
+ return fallback_tokenizer(prompt)
+
if backend_mode == 'vllm':
- return vllm.tokenize(prompt)
+ return vllm.tokenize(prompt, backend_url)
elif backend_mode == 'ooba':
return oobabooga.tokenize(prompt)
else:
diff --git a/llm_server/llm/generator.py b/llm_server/llm/generator.py
index 5dd2093..c924d38 100644
--- a/llm_server/llm/generator.py
+++ b/llm_server/llm/generator.py
@@ -1,14 +1,15 @@
from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
-def generator(request_json_body):
- if opts.mode == 'oobabooga':
+def generator(request_json_body, cluster_backend, timeout: int = None):
+ mode = cluster_config.get_backend(cluster_backend)['mode']
+ if mode == 'ooba':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
- elif opts.mode == 'vllm':
+ elif mode == 'vllm':
from .vllm.generate import generate
- r = generate(request_json_body)
- return r
+ return generate(request_json_body, cluster_backend, timeout=timeout)
else:
raise Exception
diff --git a/llm_server/llm/info.py b/llm_server/llm/info.py
index 5a529ba..d1218e2 100644
--- a/llm_server/llm/info.py
+++ b/llm_server/llm/info.py
@@ -3,23 +3,35 @@ import requests
from llm_server import opts
-def get_running_model():
- # TODO: cache the results for 1 min so we don't have to keep calling the backend
- # TODO: only use one try/catch
-
- if opts.mode == 'oobabooga':
+def get_running_model(backend_url: str, mode: str):
+ if mode == 'ooba':
try:
- backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+ backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['result'], None
except Exception as e:
return False, e
- elif opts.mode == 'vllm':
+ elif mode == 'vllm':
try:
- backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+ backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['model'], None
except Exception as e:
return False, e
else:
raise Exception
+
+
+def get_info(backend_url: str, mode: str):
+ if mode == 'ooba':
+ return {}
+ # raise NotImplementedError
+ elif mode == 'vllm':
+ try:
+ r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
+ j = r.json()
+ except Exception as e:
+ return {}
+ return j
+ else:
+ raise Exception
diff --git a/llm_server/llm/llm_backend.py b/llm_server/llm/llm_backend.py
index 1c11c17..f864b18 100644
--- a/llm_server/llm/llm_backend.py
+++ b/llm_server/llm/llm_backend.py
@@ -2,14 +2,17 @@ from typing import Tuple, Union
import flask
-from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import get_token_count
-from llm_server.routes.cache import redis
class LLMBackend:
_default_params: dict
+ def __init__(self, backend_url: str):
+ self.backend_url = backend_url
+ self.backend_info = cluster_config.get_backend(self.backend_url)
+
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError
@@ -32,14 +35,16 @@ class LLMBackend:
"""
If a backend needs to do other checks not related to the prompt or parameters.
Default is no extra checks preformed.
+ :param request:
+ :param prompt:
:param parameters:
:return:
"""
return True, None
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
- prompt_len = get_token_count(prompt)
- if prompt_len > opts.context_size - 10:
- model_name = redis.get('running_model', str, 'NO MODEL ERROR')
- return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
+ prompt_len = get_token_count(prompt, self.backend_url)
+ token_limit = self.backend_info['model_config']['max_position_embeddings']
+ if prompt_len > token_limit - 10:
+ return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
return True, None
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index 4336756..18fe6b1 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -1,78 +1,6 @@
-from flask import jsonify
-
from ..llm_backend import LLMBackend
-from ...database.database import log_prompt
-from ...helpers import safe_list_get
-from ...routes.cache import redis
-from ...routes.helpers.client import format_sillytavern_err
-from ...routes.helpers.http import validate_json
class OobaboogaBackend(LLMBackend):
- default_params = {}
-
- def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
- raise NotImplementedError('need to implement default_params')
-
- backend_err = False
- response_valid_json, response_json_body = validate_json(response)
- if response:
- try:
- # Be extra careful when getting attributes from the response object
- response_status_code = response.status_code
- except:
- response_status_code = 0
- else:
- response_status_code = None
-
- # ===============================================
-
- # We encountered an error
- if not success or not response or error_msg:
- if not error_msg or error_msg == '':
- error_msg = 'Unknown error.'
- else:
- error_msg = error_msg.strip('.') + '.'
- backend_response = format_sillytavern_err(error_msg, 'error')
- log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
- return jsonify({
- 'code': 500,
- 'msg': error_msg,
- 'results': [{'text': backend_response}]
- }), 400
-
- # ===============================================
-
- if response_valid_json:
- backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
- if not backend_response:
- # Ooba doesn't return any error messages so we will just tell the client an error occurred
- backend_err = True
- backend_response = format_sillytavern_err(
- f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
- 'error')
- response_json_body['results'][0]['text'] = backend_response
-
- if not backend_err:
- redis.incr('proompts')
-
- log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
- return jsonify({
- **response_json_body
- }), 200
- else:
- backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
- log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
- return jsonify({
- 'code': 500,
- 'msg': 'the backend did not return valid JSON',
- 'results': [{'text': backend_response}]
- }), 400
-
- def validate_params(self, params_dict: dict):
- # No validation required
- return True, None
-
- def get_parameters(self, parameters):
- del parameters['prompt']
- return parameters
+    # Stubbed out; the oobabooga backend is not currently implemented.
+    pass
diff --git a/llm_server/llm/openai/moderation.py b/llm_server/llm/openai/moderation.py
index 53e234d..f62241d 100644
--- a/llm_server/llm/openai/moderation.py
+++ b/llm_server/llm/openai/moderation.py
@@ -10,7 +10,7 @@ def check_moderation_endpoint(prompt: str):
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
if response.status_code != 200:
- print(response.text)
+ print('moderation failed:', response)
response.raise_for_status()
response = response.json()
diff --git a/llm_server/llm/openai/oai_to_vllm.py b/llm_server/llm/openai/oai_to_vllm.py
new file mode 100644
index 0000000..ef07a08
--- /dev/null
+++ b/llm_server/llm/openai/oai_to_vllm.py
@@ -0,0 +1,97 @@
+from flask import jsonify
+
+from llm_server import opts
+
+
+def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
+ if not request_json_body.get('stop'):
+ request_json_body['stop'] = []
+ if not isinstance(request_json_body['stop'], list):
+ # It is a string, so create a list with the existing element.
+ request_json_body['stop'] = [request_json_body['stop']]
+
+ if stop_hashes:
+ if opts.openai_force_no_hashes:
+ request_json_body['stop'].append('###')
+ else:
+            # TODO: make the stopping strings configurable
+ request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
+ else:
+ request_json_body['stop'].extend(['user:', 'assistant:'])
+
+ if request_json_body.get('frequency_penalty', 0) < -2:
+ request_json_body['frequency_penalty'] = -2
+ elif request_json_body.get('frequency_penalty', 0) > 2:
+ request_json_body['frequency_penalty'] = 2
+
+ if mode == 'vllm' and request_json_body.get('top_p') == 0:
+ request_json_body['top_p'] = 0.01
+
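+    # Accept either ooba-style max_new_tokens or OpenAI-style max_tokens, capped at the server's limit.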
+ request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
+ if request_json_body['max_tokens'] == 0:
+ # We don't want to set any defaults here.
+ del request_json_body['max_tokens']
+
+ return request_json_body
+
+
+def format_oai_err(err_msg):
+ print('OAI ERROR MESSAGE:', err_msg)
+ return jsonify({
+ "error": {
+ "message": err_msg,
+ "type": "invalid_request_error",
+ "param": None,
+ "code": None
+ }
+ }), 400
+
+
+def validate_oai(parameters):
+ if parameters.get('messages'):
+ for m in parameters['messages']:
+ if m['role'].lower() not in ['assistant', 'user', 'system']:
+ return format_oai_err('messages role must be assistant, user, or system')
+
+    if parameters.get('temperature', 0) > 2:
+        return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
+    if parameters.get('temperature', 0) < 0:
+        return format_oai_err(f"{parameters['temperature']} is less than the minimum of 0 - 'temperature'")
+
+    if parameters.get('top_p', 1) > 1:
+        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
+    if parameters.get('top_p', 1) < 0:
+        return format_oai_err(f"{parameters['top_p']} is less than the minimum of 0 - 'top_p'")
+
+    if parameters.get('presence_penalty', 1) > 2:
+        return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
+    if parameters.get('presence_penalty', 1) < -2:
+        return format_oai_err(f"{parameters['presence_penalty']} is less than the minimum of -2 - 'presence_penalty'")
+
+ if parameters.get('max_tokens', 2) < 1:
+ return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
+
+
+def return_invalid_model_err(requested_model: str):
+ if requested_model:
+ msg = f"The model `{requested_model}` does not exist"
+ else:
+ msg = "The requested model does not exist"
+ return jsonify({
+ "error": {
+ "message": msg,
+ "type": "invalid_request_error",
+ "param": None,
+ "code": "model_not_found"
+ }
+ }), 404
diff --git a/llm_server/llm/openai/transform.py b/llm_server/llm/openai/transform.py
index d5b64e3..daec3dc 100644
--- a/llm_server/llm/openai/transform.py
+++ b/llm_server/llm/openai/transform.py
@@ -2,86 +2,35 @@ import concurrent.futures
import re
import secrets
import string
-import time
import traceback
from typing import Dict, List
import tiktoken
-from flask import jsonify, make_response
-import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
-from llm_server.routes.cache import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
-def build_openai_response(prompt, response, model=None):
- # Seperate the user's prompt from the context
- x = prompt.split('### USER:')
- if len(x) > 1:
- prompt = re.sub(r'\n$', '', x[-1].strip(' '))
-
- # Make sure the bot doesn't put any other instructions in its response
- # y = response.split('\n### ')
- # if len(y) > 1:
- # response = re.sub(r'\n$', '', y[0].strip(' '))
- response = re.sub(ANTI_RESPONSE_RE, '', response)
- response = re.sub(ANTI_CONTINUATION_RE, '', response)
-
- # TODO: async/await
- prompt_tokens = llm_server.llm.get_token_count(prompt)
- response_tokens = llm_server.llm.get_token_count(response)
- running_model = redis.get('running_model', str, 'ERROR')
-
- response = make_response(jsonify({
- "id": f"chatcmpl-{generate_oai_string(30)}",
- "object": "chat.completion",
- "created": int(time.time()),
- "model": running_model if opts.openai_expose_our_model else model,
- "choices": [{
- "index": 0,
- "message": {
- "role": "assistant",
- "content": response,
- },
- "logprobs": None,
- "finish_reason": "stop"
- }],
- "usage": {
- "prompt_tokens": prompt_tokens,
- "completion_tokens": response_tokens,
- "total_tokens": prompt_tokens + response_tokens
- }
- }), 200)
-
- stats = redis.get('proxy_stats', dict)
- if stats:
- response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
- return response
-
-
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))
-def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
- tokenizer = tiktoken.get_encoding("cl100k_base")
-
- def get_token_count_tiktoken_thread(msg):
- return len(tokenizer.encode(msg["content"]))
+def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
+ def get_token_count_thread(msg):
+ return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
- token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
+ token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
- formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
+ formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
# If total tokens exceed the limit, start trimming
- if total_tokens > context_token_limit:
+ if total_tokens + formatting_tokens > context_token_limit:
while True:
while total_tokens + formatting_tokens > context_token_limit:
# Calculate the index to start removing messages from
@@ -94,22 +43,43 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
break
- def get_token_count_thread(msg):
- return get_token_count(msg["content"])
-
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
- formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
-
+ formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
else:
break
+ return prompt
+
+def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
+ tokenizer = tiktoken.get_encoding("cl100k_base")
+ token_count = get_token_count(prompt, backend_url)
+
+ # If total tokens exceed the limit, start trimming
+ if token_count > context_token_limit:
+ while True:
+ while token_count > context_token_limit:
+ # Calculate the index to start removing characters from
+ remove_index = len(prompt) // 3
+
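+                # Remove 100-character slices starting a third of the way in, re-checking the local estimate each pass.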
+ while remove_index < len(prompt):
+ prompt = prompt[:remove_index] + prompt[remove_index + 100:]
+ token_count = len(tokenizer.encode(prompt))
+ if token_count <= context_token_limit or remove_index == len(prompt):
+ break
+
+ token_count = get_token_count(prompt, backend_url)
+ if token_count > context_token_limit:
+ # Start over, but this time calculate the token count using the backend
+ token_count = get_token_count(prompt, backend_url)
+ else:
+ break
return prompt
@@ -117,8 +87,9 @@ def transform_messages_to_prompt(oai_messages):
try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
for msg in oai_messages:
- if not msg.get('content') or not msg.get('role'):
+ if 'content' not in msg.keys() or 'role' not in msg.keys():
return False
+ msg['content'] = str(msg['content']) # Prevent any weird issues.
if msg['role'] == 'system':
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
elif msg['role'] == 'user':
@@ -126,7 +97,7 @@ def transform_messages_to_prompt(oai_messages):
elif msg['role'] == 'assistant':
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
else:
- return False
+ raise Exception(f'Unknown role: {msg["role"]}')
except Exception as e:
# TODO: use logging
traceback.print_exc()
diff --git a/llm_server/llm/vllm/generate.py b/llm_server/llm/vllm/generate.py
index 1549f2e..31cd511 100644
--- a/llm_server/llm/vllm/generate.py
+++ b/llm_server/llm/vllm/generate.py
@@ -1,80 +1,21 @@
"""
This file is used by the worker that processes requests.
"""
-import json
-import time
-from uuid import uuid4
import requests
-import llm_server
from llm_server import opts
-from llm_server.routes.cache import redis
# TODO: make the VLMM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
def prepare_json(json_data: dict):
- # logit_bias is not currently supported
- # del json_data['logit_bias']
-
# Convert back to VLLM.
json_data['max_tokens'] = json_data.pop('max_new_tokens')
return json_data
-def transform_to_text(json_request, api_response):
- """
- This is to convert a streaming request to a non-streamed request. Don't think this is nessesary.
- :param json_request:
- :param api_response:
- :return:
- """
- prompt = transform_prompt_to_text(json_request['messages'])
- text = ''
- finish_reason = None
- for line in api_response.split('\n'):
- if line.startswith('data:'):
- try:
- data = json.loads(line[5:].strip())
- except json.decoder.JSONDecodeError:
- break
- if 'choices' in data:
- for choice in data['choices']:
- if 'delta' in choice and 'content' in choice['delta']:
- text += choice['delta']['content']
- if data['choices'][0]['finish_reason']:
- finish_reason = data['choices'][0]['finish_reason']
-
- prompt_tokens = len(llm_server.llm.get_token_count(prompt))
- completion_tokens = len(llm_server.llm.get_token_count(text))
- running_model = redis.get('running_model', str, 'ERROR')
-
- # https://platform.openai.com/docs/api-reference/making-requests?lang=python
- return {
- "id": str(uuid4()),
- "object": "chat.completion",
- "created": int(time.time()),
- "model": running_model,
- "usage": {
- "prompt_tokens": prompt_tokens,
- "completion_tokens": completion_tokens,
- "total_tokens": prompt_tokens + completion_tokens
- },
- "choices": [
- {
- "message": {
- "role": "assistant",
- "content": text
- },
- "finish_reason": finish_reason,
- "index": 0
- }
- ]
- }
-
-
def transform_prompt_to_text(prompt: list):
text = ''
for item in prompt:
@@ -82,26 +23,26 @@ def transform_prompt_to_text(prompt: list):
return text.strip('\n')
-def handle_blocking_request(json_data: dict):
+def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
try:
- r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+ r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except requests.exceptions.ReadTimeout:
- print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
+ # print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out'
except Exception as e:
- print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
+ # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False, None, 'Request to backend encountered error'
if r.status_code != 200:
- print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
+ # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
return False, r, f'Backend returned {r.status_code}'
return True, r, None
-def generate(json_data: dict):
+def generate(json_data: dict, cluster_backend, timeout: int = None):
if json_data.get('stream'):
try:
- return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+ return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except Exception as e:
- print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
+ return False
else:
- return handle_blocking_request(json_data)
+ return handle_blocking_request(json_data, cluster_backend, timeout=timeout)
diff --git a/llm_server/llm/vllm/info.py b/llm_server/llm/vllm/info.py
index 996c614..0142301 100644
--- a/llm_server/llm/vllm/info.py
+++ b/llm_server/llm/vllm/info.py
@@ -1,3 +1,7 @@
+import requests
+
+from llm_server import opts
+
vllm_info = """
Important: This endpoint is running vllm and not all Oobabooga parameters are supported.
Supported Parameters:
"""
\ No newline at end of file
+"""
diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
index a698fd6..bdb6650 100644
--- a/llm_server/llm/vllm/tokenize.py
+++ b/llm_server/llm/vllm/tokenize.py
@@ -1,26 +1,51 @@
+import concurrent.futures
+
import requests
import tiktoken
from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.logging import create_logger
-def tokenize(prompt: str) -> int:
+def tokenize(prompt: str, backend_url: str) -> int:
+ assert backend_url
+ assert isinstance(backend_url, str)
+
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
+ assert isinstance(prompt, str)
+
+ logger = create_logger('tokenizer')
+
+ # The backend could have died between when the request was
+ # submitted and now, so let's double check it's still online.
+ backend_url = cluster_config.validate_backend(backend_url)
+
tokenizer = tiktoken.get_encoding("cl100k_base")
- # First we tokenize it locally to determine if it's worth sending it to the backend.
- initial_estimate = len(tokenizer.encode(prompt))
- if initial_estimate <= opts.context_size + 200:
+ # Split the prompt into 2000 character chunks
+ chunk_size = 2000
+ chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
+
+ # Define a function to send a chunk to the server
+ def send_chunk(chunk):
try:
- r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+ r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json()
return j['length']
except Exception as e:
- print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
- return len(tokenizer.encode(prompt)) + 10
- else:
- # If the result was greater than our context size, return the estimate.
- # We won't be sending it through the backend so it does't need to be accurage.
- return initial_estimate
+ logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
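+            # Fall back to the local tiktoken estimate, padded by a few tokens since it may not match the backend's tokenizer exactly.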
+ return len(tokenizer.encode(chunk)) + 10
+
+ # Use a ThreadPoolExecutor to send all chunks to the server at once
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ future_to_chunk = {executor.submit(send_chunk, chunk): chunk for chunk in chunks}
+ for future in concurrent.futures.as_completed(future_to_chunk):
+ chunk = future_to_chunk[future]
+ try:
+ data = future.result()
+ except Exception as exc:
+ logger.warning('%r generated an exception: %s' % (chunk, exc))
+ return sum(future.result() for future in future_to_chunk)
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index e5b0fad..5c12b45 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -1,10 +1,9 @@
-import threading
from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
-from llm_server.database.database import log_prompt
+from llm_server.database.log_to_db import log_to_db
from llm_server.llm.llm_backend import LLMBackend
@@ -19,16 +18,8 @@ class VLLMBackend(LLMBackend):
# Failsafe
backend_response = ''
- r_url = request.url
-
- def background_task():
- log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
- response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
-
- # TODO: use async/await instead of threads
- thread = threading.Thread(target=background_task)
- thread.start()
- thread.join()
+ log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
+ response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
return jsonify({'results': [{'text': backend_response}]}), 200
@@ -38,14 +29,20 @@ class VLLMBackend(LLMBackend):
top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0:
top_k = -1
+
+ # TODO: support more params
sampling_params = SamplingParams(
temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
- stop=parameters.get('stopping_strings', self._default_params['stop']),
+ stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False),
- max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
+ max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
+ presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
+ frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
+ length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
+ early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
)
except ValueError as e:
return None, str(e).strip('.')
diff --git a/llm_server/logging.py b/llm_server/logging.py
new file mode 100644
index 0000000..7e9aa74
--- /dev/null
+++ b/llm_server/logging.py
@@ -0,0 +1,52 @@
+import logging
+
+import coloredlogs
+
+from llm_server import opts
+
+
+class LoggingInfo:
+ def __init__(self):
+ self._level = logging.INFO
+ self._format = opts.LOGGING_FORMAT
+
+ @property
+ def level(self):
+ return self._level
+
+ @level.setter
+ def level(self, value):
+ self._level = value
+
+ @property
+ def format(self):
+ return self._format
+
+ @format.setter
+ def format(self, value):
+ self._format = value
+
+
+logging_info = LoggingInfo()
+
+
+def init_logging():
+ """
+ Set up the parent logger.
+ :return:
+ """
+ logger = logging.getLogger('llm_server')
+ logger.setLevel(logging_info.level)
+
+
+def create_logger(name):
+ logger = logging.getLogger('llm_server').getChild(name)
+ logger.setLevel(logging_info.level)
+ if not logger.handlers:
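+        # Only attach a handler the first time so repeated create_logger() calls don't duplicate output.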
+ handler = logging.StreamHandler()
+ handler.setLevel(logging_info.level)
+ formatter = logging.Formatter(logging_info.format)
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+ coloredlogs.install(logger=logger, level=logging_info.level)
+ return logger
diff --git a/llm_server/messages.py b/llm_server/messages.py
new file mode 100644
index 0000000..c7e3eb7
--- /dev/null
+++ b/llm_server/messages.py
@@ -0,0 +1 @@
+BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
diff --git a/llm_server/netdata.py b/llm_server/netdata.py
deleted file mode 100644
index f37c109..0000000
--- a/llm_server/netdata.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import json
-from datetime import datetime, timedelta
-
-import requests
-
-from llm_server import opts
-
-
-def get_power_states():
- gpu_num = 0
- output = {}
- while True:
- url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
- try:
- response = requests.get(url, timeout=10)
- if response.status_code != 200:
- break
- data = json.loads(response.text)
- power_state_data = data['data'][0]
- power_state = None
- for i in range(1, len(power_state_data)):
- if power_state_data[i] == 1:
- power_state = data['labels'][i]
- break
- output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
- except Exception as e:
- print('Failed to fetch Netdata metrics:', e)
- return output
- gpu_num += 1
- return output
-
-
-def get_gpu_wh(gpu_id: int):
- chart_name = f"nvidia_smi.gpu{gpu_id}_power"
- now = datetime.now()
- one_hour_ago = now - timedelta(hours=1)
- num_seconds = int((now - one_hour_ago).total_seconds())
- params = {
- "chart": chart_name,
- "after": int(one_hour_ago.timestamp()),
- "before": int(now.timestamp()),
- "points": num_seconds,
- "group": "second",
- "format": "json",
- "options": "absolute|jsonwrap"
- }
- response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
- data = json.loads(response.text)
- total_power_usage_watts = sum(point[1] for point in data['result']['data'])
- # total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
- total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
- return total_power_usage_kwh
diff --git a/llm_server/opts.py b/llm_server/opts.py
index de23c7a..f75ba94 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -1,12 +1,11 @@
# Read-only global variables
+# Uppercase variables are read-only globals.
+# Lowercase variables are set once at startup and never changed afterward.
+
# TODO: rewrite the config system so I don't have to add every single config default here
-running_model = 'ERROR'
-concurrent_gens = 3
-mode = 'oobabooga'
-backend_url = None
-context_size = 5555
+frontend_api_mode = 'ooba'
max_new_tokens = 500
auth_required = False
log_prompts = False
@@ -33,7 +32,15 @@ openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
-openai_moderation_workers = 10
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
+cluster = {}
+show_backends = True
+background_homepage_cacher = True
+openai_moderation_timeout = 5
+prioritize_by_size = False
+cluster_workers = 0
+redis_stream_timeout = 25000
+
+LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
diff --git a/llm_server/pre_fork.py b/llm_server/pre_fork.py
index f3ea0f4..6e8c1ad 100644
--- a/llm_server/pre_fork.py
+++ b/llm_server/pre_fork.py
@@ -1,21 +1,9 @@
import sys
-from redis import Redis
-
-from llm_server.routes.cache import redis
-from llm_server.routes.v1.generate_stats import generate_stats
+from llm_server.custom_redis import redis
def server_startup(s):
- if not redis.get('daemon_started', bool):
+ if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)
-
- # Flush the RedisPriorityQueue database.
- queue_redis = Redis(host='localhost', port=6379, db=15)
- for key in queue_redis.scan_iter('*'):
- queue_redis.delete(key)
-
- # Cache the initial stats
- print('Loading backend stats...')
- generate_stats()
diff --git a/llm_server/routes/helpers/client.py b/llm_server/routes/helpers/client.py
index 48e721e..5031b8b 100644
--- a/llm_server/routes/helpers/client.py
+++ b/llm_server/routes/helpers/client.py
@@ -1,11 +1,18 @@
-from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis
-def format_sillytavern_err(msg: str, level: str = 'info'):
- http_host = redis.get('http_host', str)
+def format_sillytavern_err(msg: str, backend_url: str = None, error_type: str = 'info'):
+ if backend_url:
+ cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
+ else:
+ cluster_backend_hash = 'none'
+ http_host = redis.get('http_host', dtype=str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
--> {level.upper()} <-
+-> {error_type.upper()} <-
{msg}
+```
+```
+BACKEND: {cluster_backend_hash}
```"""
diff --git a/llm_server/routes/helpers/http.py b/llm_server/routes/helpers/http.py
index 2fa1190..a3f1906 100644
--- a/llm_server/routes/helpers/http.py
+++ b/llm_server/routes/helpers/http.py
@@ -100,4 +100,4 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
j = json.loads(str(data))
return True, j
except Exception as e:
- return False, e
+ return False, e
\ No newline at end of file
diff --git a/llm_server/routes/helpers/model.py b/llm_server/routes/helpers/model.py
new file mode 100644
index 0000000..bf18b66
--- /dev/null
+++ b/llm_server/routes/helpers/model.py
@@ -0,0 +1,15 @@
+def estimate_model_size(config: dict):
+ """
+    Estimate the size of a model (in billions of parameters) from its config.
+    The estimate is rough, but it is good enough to compare models against each other.
+ :param config:
+ :return:
+ """
+ vocab_size = config.get('vocab_size')
+ hidden_size = config.get('hidden_size')
+ num_hidden_layers = config.get('num_hidden_layers')
+ intermediate_size = config.get('intermediate_size')
+ if vocab_size and hidden_size and num_hidden_layers and intermediate_size:
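+        # Rough estimate: embedding table plus an approximate per-layer cost for the attention and MLP blocks, reported in billions of parameters.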
+ total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3)))
+ return int(total_params / 1e9)
+ return 0
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index d6b02e2..aadda78 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -3,8 +3,8 @@ from typing import Tuple
import flask
from flask import jsonify, request
-from llm_server import opts
-from llm_server.database.database import log_prompt
+from llm_server import messages, opts
+from llm_server.database.log_to_db import log_to_db
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@@ -13,8 +13,11 @@ class OobaRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- def handle_request(self):
+ def handle_request(self, return_ok: bool = True):
assert not self.used
+ if self.offline:
+ print('This backend is offline:', messages.BACKEND_OFFLINE)
+ return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()
if not request_valid:
@@ -25,14 +28,19 @@ class OobaRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': prompt}
_, backend_response = self.generate_response(llm_request)
- return backend_response
+ if return_ok:
+ # Always return 200 so ST displays our error messages
+ return backend_response[0], 200
+ else:
+ # The OpenAI route needs to detect 429 errors.
+ return backend_response
def handle_ratelimited(self, do_log: bool = True):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
- log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
- return backend_response[0], 200 # We only return the response from handle_error(), not the error code
+ log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+ return backend_response[0], 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
@@ -40,7 +48,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this
response_msg = error_msg
else:
- response_msg = format_sillytavern_err(error_msg, error_type)
+ response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
return jsonify({
'results': [{'text': response_msg}]
diff --git a/llm_server/routes/openai/__init__.py b/llm_server/routes/openai/__init__.py
index 67febc9..c3837e4 100644
--- a/llm_server/routes/openai/__init__.py
+++ b/llm_server/routes/openai/__init__.py
@@ -5,9 +5,11 @@ from ..server_error import handle_server_error
from ... import opts
openai_bp = Blueprint('openai/v1/', __name__)
+openai_model_bp = Blueprint('openai/', __name__)
@openai_bp.before_request
+@openai_model_bp.before_request
def before_oai_request():
if not opts.enable_openi_compatible_backend:
return 'The OpenAI-compatible backend is disabled.', 401
@@ -15,8 +17,22 @@ def before_oai_request():
@openai_bp.errorhandler(500)
+@openai_model_bp.errorhandler(500)
def handle_error(e):
- return handle_server_error(e)
+ """
+ Found Codes:
+ "auth_subrequest_error"
+ """
+
+ print('OAI returning error:', e)
+ return jsonify({
+ "error": {
+ "message": "Internal server error",
+ "type": "auth_subrequest_error",
+ "param": None,
+ "code": "internal_error"
+ }
+ }), 500
from .models import openai_list_models
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index cc27dce..9ccc15f 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -1,113 +1,175 @@
import json
-import threading
import time
import traceback
+import ujson
from flask import Response, jsonify, request
+from redis import Redis
-from . import openai_bp
-from ..cache import redis
+from llm_server.custom_redis import redis
+from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
+from ..queue import priority_queue
from ... import opts
-from ...database.database import log_prompt
-from ...llm.generator import generator
-from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
-from ...llm.vllm import tokenize
+from ...database.log_to_db import log_to_db
+from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
+from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
# TODO: add rate-limit headers?
+
@openai_bp.route('/chat/completions', methods=['POST'])
-def openai_chat_completions():
+@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
+def openai_chat_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
- handler = OpenAIRequestHandler(request, request_json_body)
- if request_json_body.get('stream'):
- if not opts.enable_streaming:
- # TODO: return a proper OAI error message
- return 'disabled', 401
+ handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
+ if handler.offline:
+ return return_invalid_model_err(model_name)
- if opts.mode != 'vllm':
- # TODO: implement other backends
- raise NotImplementedError
-
- response_status_code = 0
- start_time = time.time()
- request_valid, invalid_response = handler.validate_request()
- if not request_valid:
- # TODO: simulate OAI here
- raise Exception('TODO: simulate OAI here')
- else:
- handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
- msg_to_backend = {
- **handler.parameters,
- 'prompt': handler.prompt,
- 'stream': True,
- }
- try:
- response = generator(msg_to_backend)
- r_headers = dict(request.headers)
- r_url = request.url
- model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
- oai_string = generate_oai_string(30)
-
- def generate():
- generated_text = ''
- partial_response = b''
- for chunk in response.iter_content(chunk_size=1):
- partial_response += chunk
- if partial_response.endswith(b'\x00'):
- json_strs = partial_response.split(b'\x00')
- for json_str in json_strs:
- if json_str:
- try:
- json_obj = json.loads(json_str.decode())
- new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
- generated_text = generated_text + new
- except IndexError:
- # ????
- continue
-
- data = {
- "id": f"chatcmpl-{oai_string}",
- "object": "chat.completion.chunk",
- "created": int(time.time()),
- "model": model,
- "choices": [
- {
- "index": 0,
- "delta": {
- "content": new
- },
- "finish_reason": None
- }
- ]
- }
- yield f'data: {json.dumps(data)}\n\n'
-
- yield 'data: [DONE]\n\n'
- end_time = time.time()
- elapsed_time = end_time - start_time
-
- def background_task():
- generated_tokens = tokenize(generated_text)
- log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
-
- # TODO: use async/await instead of threads
- thread = threading.Thread(target=background_task)
- thread.start()
- thread.join()
-
- return Response(generate(), mimetype='text/event-stream')
- except:
- # TODO: simulate OAI here
- raise Exception
- else:
+ if not request_json_body.get('stream'):
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500
+ else:
+ if not opts.enable_streaming:
+ return 'Streaming disabled', 403
+
+ invalid_oai_err_msg = validate_oai(handler.request_json_body)
+ if invalid_oai_err_msg:
+ return invalid_oai_err_msg
+
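+            # Translate the OpenAI-style parameters into the backend's (vLLM) format before building the prompt and queueing the request.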
+ handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
+
+ handler.parameters, e = handler.get_parameters()
+ handler.request_json_body = {
+ 'messages': handler.request_json_body['messages'],
+ 'model': handler.request_json_body['model'],
+ **handler.parameters
+ }
+
+ if opts.openai_silent_trim:
+ handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
+ else:
+ handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
+ if not handler.prompt:
+ # Prevent issues on the backend.
+ return 'Invalid prompt', 400
+
+ # Need to set the prompt in the JSON body since that's what the inference worker expects.
+ handler.request_json_body['prompt'] = handler.prompt
+
+ start_time = time.time()
+
+ request_valid, invalid_response = handler.validate_request()
+ if not request_valid:
+ return invalid_response
+
+ event = None
+ if not handler.is_client_ratelimited():
+ event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
+ if not event:
+ log_to_db(
+ handler.client_ip,
+ handler.token,
+ handler.prompt,
+ None,
+ None,
+ handler.parameters,
+ request.headers,
+ 429,
+ request.url,
+ handler.backend_url,
+ )
+ return handler.handle_ratelimited()
+
+ try:
+ r_headers = dict(request.headers)
+ r_url = request.url
+ model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
+ oai_string = generate_oai_string(30)
+
+ # Need to do this before we enter generate() since we want to be able to
+ # return a 408 if necessary.
+ _, stream_name, error_msg = event.wait()
+ if error_msg:
+ print('OAI failed to start streaming:', error_msg)
+                stream_name = None  # Set to None so the finally block doesn't try to delete the stream.
+ return 'Request Timeout', 408
+
+ def generate():
+ stream_redis = Redis(db=8)
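+                # Generated tokens are read from a per-request Redis stream in db 8; stream_name was returned by event.wait() above.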
+ generated_text = ''
+ try:
+ last_id = '0-0'
+ while True:
+ stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
+ if not stream_data:
+ print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+ yield 'data: [DONE]\n\n'
+ else:
+ for stream_index, item in stream_data[0][1]:
+ last_id = stream_index
+ timestamp = int(stream_index.decode('utf-8').split('-')[0])
+ data = ujson.loads(item[b'data'])
+ if data['error']:
+ # Not printing error since we can just check the daemon log.
+ print('OAI streaming encountered error')
+ yield 'data: [DONE]\n\n'
+ return
+ elif data['new']:
+ response = {
+ "id": f"chatcmpl-{oai_string}",
+ "object": "chat.completion.chunk",
+ "created": timestamp,
+ "model": model,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {
+ "content": data['new']
+ },
+ "finish_reason": None
+ }
+ ]
+ }
+ generated_text = generated_text + data['new']
+ yield f'data: {json.dumps(response)}\n\n'
+ elif data['completed']:
+ yield 'data: [DONE]\n\n'
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ log_to_db(
+ handler.client_ip,
+ handler.token,
+ handler.prompt,
+ generated_text,
+ elapsed_time,
+ handler.parameters,
+ r_headers,
+ 200,
+ r_url,
+ handler.backend_url,
+ )
+ return
+ except GeneratorExit:
+ return
+ except Exception:
+ traceback.print_exc()
+ yield 'data: [DONE]\n\n'
+ finally:
+ if event:
+ redis.publish(f'notifications:{event.event_id}', 'canceled')
+ if stream_name:
+ stream_redis.delete(stream_name)
+
+ return Response(generate(), mimetype='text/event-stream')
+ except Exception:
+ traceback.print_exc()
+            return 'Internal Server Error', 500
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 503f628..2524b17 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -1,38 +1,68 @@
import time
import traceback
-from flask import jsonify, make_response, request
+import simplejson as json
+import ujson
+from flask import Response, jsonify, request
+from redis import Redis
-from . import openai_bp
-from ..cache import redis
-from ..helpers.client import format_sillytavern_err
+from llm_server.custom_redis import redis
+from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
+from ..queue import priority_queue
from ... import opts
+from ...database.log_to_db import log_to_db
from ...llm import get_token_count
-from ...llm.openai.transform import build_openai_response, generate_oai_string
+from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
+from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST'])
-def openai_completions():
+@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
+def openai_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
- try:
- response, status_code = OobaRequestHandler(request).handle_request()
- if status_code != 200:
- return status_code
+ handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
+ if handler.offline:
+ return return_invalid_model_err(model_name)
+
+ if handler.cluster_backend_info['mode'] != 'vllm':
+ # TODO: implement other backends
+ raise NotImplementedError
+
+ invalid_oai_err_msg = validate_oai(handler.request_json_body)
+ if invalid_oai_err_msg:
+ return invalid_oai_err_msg
+ handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
+
+ if opts.openai_silent_trim:
+ handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+ else:
+ # The handle_request() call below will load the prompt so we don't have
+ # to do anything else here.
+ pass
+
+ handler.request_json_body['prompt'] = handler.prompt
+
+ if not request_json_body.get('stream'):
+ invalid_oai_err_msg = validate_oai(request_json_body)
+ if invalid_oai_err_msg:
+ return invalid_oai_err_msg
+ response, status_code = handler.handle_request(return_ok=False)
+ if status_code == 429:
+ return handler.handle_ratelimited()
output = response.json['results'][0]['text']
- # TODO: async/await
- prompt_tokens = get_token_count(request_json_body['prompt'])
- response_tokens = get_token_count(output)
- running_model = redis.get('running_model', str, 'ERROR')
+ prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
+ response_tokens = get_token_count(output, handler.backend_url)
+ running_model = redis.get('running_model', 'ERROR', dtype=str)
- response = make_response(jsonify({
+ response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
@@ -42,7 +72,7 @@ def openai_completions():
"text": output,
"index": 0,
"logprobs": None,
- "finish_reason": None
+ "finish_reason": "stop"
}
],
"usage": {
@@ -50,12 +80,141 @@ def openai_completions():
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
- }), 200)
+ })
- stats = redis.get('proxy_stats', dict)
- if stats:
- response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
- return response
- except Exception:
- traceback.print_exc()
- return 'Internal Server Error', 500
+ # TODO:
+ # stats = redis.get('proxy_stats', dtype=dict)
+ # if stats:
+ # response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
+ return response, 200
+ else:
+ if not opts.enable_streaming:
+ return 'Streaming disabled', 403
+
+ request_valid, invalid_response = handler.validate_request()
+ if not request_valid:
+ return invalid_response
+
+ handler.parameters, _ = handler.get_parameters()
+ handler.request_json_body = {
+ 'prompt': handler.request_json_body['prompt'],
+ 'model': handler.request_json_body['model'],
+ **handler.parameters
+ }
+
+ invalid_oai_err_msg = validate_oai(handler.request_json_body)
+ if invalid_oai_err_msg:
+ return invalid_oai_err_msg
+
+ if opts.openai_silent_trim:
+ handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
+ if not handler.prompt:
+ # Prevent issues on the backend.
+ return 'Invalid prompt', 400
+
+ start_time = time.time()
+
+ request_valid, invalid_response = handler.validate_request()
+ if not request_valid:
+ return invalid_response
+
+ event = None
+ if not handler.is_client_ratelimited():
+ event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
+ if not event:
+ log_to_db(
+ handler.client_ip,
+ handler.token,
+ handler.prompt,
+ None,
+ None,
+ handler.parameters,
+ request.headers,
+ 429,
+ request.url,
+ handler.backend_url,
+ )
+ return handler.handle_ratelimited()
+
+ try:
+ r_headers = dict(request.headers)
+ r_url = request.url
+ model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
+ oai_string = generate_oai_string(30)
+
+ _, stream_name, error_msg = event.wait()
+ if error_msg:
+ print('OAI failed to start streaming:', error_msg)
+ stream_name = None
+ return 'Request Timeout', 408
+
+ def generate():
+ stream_redis = Redis(db=8)
+ generated_text = ''
+ try:
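+                    # Start from ID 0-0 so the first XREAD call returns everything already written to the stream.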
+ last_id = '0-0'
+ while True:
+ stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
+ if not stream_data:
+ print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+ yield 'data: [DONE]\n\n'
+ else:
+ for stream_index, item in stream_data[0][1]:
+ last_id = stream_index
+ timestamp = int(stream_index.decode('utf-8').split('-')[0])
+ data = ujson.loads(item[b'data'])
+ if data['error']:
+ print('OAI streaming encountered error')
+ yield 'data: [DONE]\n\n'
+ return
+ elif data['new']:
+ response = {
+ "id": f"cmpl-{oai_string}",
+ "object": "text_completion",
+ "created": timestamp,
+ "model": model,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {
+ "content": data['new']
+ },
+ "finish_reason": None
+ }
+ ]
+ }
+ generated_text = generated_text + data['new']
+ yield f'data: {json.dumps(response)}\n\n'
+ elif data['completed']:
+ yield 'data: [DONE]\n\n'
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ log_to_db(
+ handler.client_ip,
+ handler.token,
+ handler.prompt,
+ generated_text,
+ elapsed_time,
+ handler.parameters,
+ r_headers,
+ 200,
+ r_url,
+ handler.backend_url,
+ )
+ return
+ except GeneratorExit:
+ # This should be triggered if a client disconnects early.
+ return
+ except Exception:
+ traceback.print_exc()
+ yield 'data: [DONE]\n\n'
+ finally:
+ if event:
+ redis.publish(f'notifications:{event.event_id}', 'canceled')
+ if stream_name:
+ stream_redis.delete(stream_name)
+
+ return Response(generate(), mimetype='text/event-stream')
+ except Exception:
+ traceback.print_exc()
+            return 'Internal Server Error', 500
diff --git a/llm_server/routes/openai/info.py b/llm_server/routes/openai/info.py
index 54959ae..4fc578a 100644
--- a/llm_server/routes/openai/info.py
+++ b/llm_server/routes/openai/info.py
@@ -1,7 +1,7 @@
from flask import Response
from . import openai_bp
-from ..cache import flask_cache
+from llm_server.custom_redis import flask_cache
from ... import opts
diff --git a/llm_server/routes/openai/models.py b/llm_server/routes/openai/models.py
index 47223e7..2ff0629 100644
--- a/llm_server/routes/openai/models.py
+++ b/llm_server/routes/openai/models.py
@@ -3,59 +3,58 @@ import traceback
import requests
from flask import jsonify
+from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
+from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
from ...helpers import jsonify_pretty
-from ...llm.info import get_running_model
+from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
- model, error = get_running_model()
- if not model:
+ model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
+ if not model_name:
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
- 'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
- running_model = redis.get('running_model', str, 'ERROR')
+ running_model = redis.get('running_model', 'ERROR', dtype=str)
oai = fetch_openai_models()
- r = []
+ r = {
+ "object": "list",
+ "data": oai
+ }
+ # TODO: verify this works
if opts.openai_expose_our_model:
- r = [{
- "object": "list",
- "data": [
+ r["data"].insert(0, {
+ "id": running_model,
+ "object": "model",
+ "created": int(server_start_time.timestamp()),
+ "owned_by": opts.llm_middleware_name,
+ "permission": [
{
"id": running_model,
- "object": "model",
+ "object": "model_permission",
"created": int(server_start_time.timestamp()),
- "owned_by": opts.llm_middleware_name,
- "permission": [
- {
- "id": running_model,
- "object": "model_permission",
- "created": int(server_start_time.timestamp()),
- "allow_create_engine": False,
- "allow_sampling": False,
- "allow_logprobs": False,
- "allow_search_indices": False,
- "allow_view": True,
- "allow_fine_tuning": False,
- "organization": "*",
- "group": None,
- "is_blocking": False
- }
- ],
- "root": None,
- "parent": None
+ "allow_create_engine": False,
+ "allow_sampling": False,
+ "allow_logprobs": False,
+ "allow_search_indices": False,
+ "allow_view": True,
+ "allow_fine_tuning": False,
+ "organization": "*",
+ "group": None,
+ "is_blocking": False
}
- ]
- }]
- response = jsonify_pretty(r + oai), 200
+ ],
+ "root": None,
+ "parent": None
+ })
+ response = jsonify_pretty(r), 200
return response
@@ -64,7 +63,14 @@ def fetch_openai_models():
if opts.openai_api_key:
try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
- return response.json()['data']
+ j = response.json()['data']
+
+ # The "modelperm" string appears to be user-specific, so we'll
+ # randomize it just to be safe.
+ for model in range(len(j)):
+ for p in range(len(j[model]['permission'])):
+ j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
+ return j
except:
traceback.print_exc()
return []
diff --git a/llm_server/routes/openai/simulated.py b/llm_server/routes/openai/simulated.py
index f626490..2dafedb 100644
--- a/llm_server/routes/openai/simulated.py
+++ b/llm_server/routes/openai/simulated.py
@@ -1,7 +1,7 @@
from flask import jsonify
from . import openai_bp
-from ..cache import ONE_MONTH_SECONDS, flask_cache
+from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time
@@ -17,7 +17,7 @@ def openai_organizations():
"id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()),
"title": "Personal",
- "name": "user-abcdefghijklmnopqrstuvwx",
+ "name": f"user-{generate_oai_string(24)}",
"description": "Personal org for bobjoe@0.0.0.0",
"personal": True,
"is_default": True,
diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index d97ea09..170eb77 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -1,14 +1,21 @@
import json
+import re
+import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
-from flask import jsonify
+from flask import Response, jsonify, make_response
from llm_server import opts
+from llm_server.cluster.backend import get_model_choices
+from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
-from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
+from llm_server.database.log_to_db import log_to_db
+from llm_server.llm import get_token_count
+from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
+from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
@@ -20,20 +27,37 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used
+ if self.offline:
+ msg = return_invalid_model_err(self.selected_model)
+ print('OAI Offline:', msg)
+ return self.handle_error(msg)
if opts.openai_silent_trim:
- oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
+ oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages)
+ self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+
request_valid, invalid_response = self.validate_request()
if not request_valid:
return invalid_response
- if opts.openai_api_key and is_api_key_moderated(self.token):
+ if not self.prompt:
+ # TODO: format this as an openai error message
+ return Response('Invalid prompt'), 400
+
+ # TODO: support Ooba backend
+ self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+
+ invalid_oai_err_msg = validate_oai(self.request_json_body)
+ if invalid_oai_err_msg:
+ return invalid_oai_err_msg
+
+ if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
try:
- # Gather the last message from the user and all preceeding system messages
+ # Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy()
msg_l.reverse()
tag = uuid4()
@@ -49,33 +73,40 @@ class OpenAIRequestHandler(RequestHandler):
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e:
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
- print(traceback.format_exc())
-
- # Reconstruct the request JSON with the validated parameters and prompt.
- self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
- if opts.openai_force_no_hashes:
- self.parameters['stop'].append('### ')
-
- if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
- self.request_json_body['top_p'] = 0.01
+ traceback.print_exc()
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
-
model = self.request_json_body.get('model')
if success:
- return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
+ return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else:
return backend_response, backend_response_status_code
def handle_ratelimited(self, do_log: bool = True):
- # TODO: return a simulated OpenAI error message
- # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
- return 'Ratelimited', 429
+ model_choices, default_model = get_model_choices()
+ default_model_info = model_choices[default_model]
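+        # Use the default model's estimated queue wait as the suggested retry window, falling back to 2 seconds.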
+ w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2
+ response = jsonify({
+ "error": {
+ "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
+ "type": "rate_limit_exceeded",
+ "param": None,
+ "code": None
+ }
+ })
+ response.headers['x-ratelimit-limit-requests'] = '2'
+ response.headers['x-ratelimit-remaining-requests'] = '0'
+ response.headers['x-ratelimit-reset-requests'] = f"{w}s"
+
+ if do_log:
+ log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
+
+ return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
- # TODO: return a simulated OpenAI error message
+ print('OAI Error:', error_msg)
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
@@ -84,3 +115,51 @@ class OpenAIRequestHandler(RequestHandler):
"code": None
}
}), 400
+
+ def build_openai_response(self, prompt, response, model=None):
+        # Separate the user's prompt from the context
+ x = prompt.split('### USER:')
+ if len(x) > 1:
+ prompt = re.sub(r'\n$', '', x[-1].strip(' '))
+
+ # Make sure the bot doesn't put any other instructions in its response
+ response = re.sub(ANTI_RESPONSE_RE, '', response)
+ response = re.sub(ANTI_CONTINUATION_RE, '', response)
+
+ prompt_tokens = get_token_count(prompt, self.backend_url)
+ response_tokens = get_token_count(response, self.backend_url)
+ running_model = redis.get('running_model', 'ERROR', dtype=str)
+
+ response = make_response(jsonify({
+ "id": f"chatcmpl-{generate_oai_string(30)}",
+ "object": "chat.completion",
+ "created": int(time.time()),
+ "model": running_model if opts.openai_expose_our_model else model,
+ "choices": [{
+ "index": 0,
+ "message": {
+ "role": "assistant",
+ "content": response,
+ },
+ "logprobs": None,
+ "finish_reason": "stop"
+ }],
+ "usage": {
+ "prompt_tokens": prompt_tokens,
+ "completion_tokens": response_tokens,
+ "total_tokens": prompt_tokens + response_tokens
+ }
+ }), 200)
+ return response
+
+ def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
+ self.parameters, parameters_invalid_msg = self.get_parameters()
+ if not self.parameters:
+ print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
+ return False, (Response('Invalid request, check your parameters and try again.'), 400)
+ invalid_oai_err_msg = validate_oai(self.parameters)
+ if invalid_oai_err_msg:
+ return False, invalid_oai_err_msg
+ # self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
+ # If the parameters were invalid, let the superclass deal with it.
+ return super().validate_request(prompt, do_log)
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index 84cc614..ee66580 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -1,12 +1,15 @@
-import json
import pickle
import time
+from typing import Tuple
from uuid import uuid4
+import ujson as json
from redis import Redis
from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import RedisCustom, redis
+from llm_server.database.database import get_token_ratelimit
def increment_ip_count(client_ip: str, redis_key):
@@ -20,24 +23,30 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue:
- def __init__(self):
- self.redis = Redis(host='localhost', port=6379, db=15)
- self.pubsub = self.redis.pubsub()
- self.pubsub.subscribe('events')
+ """
+ A queue for a specific backend.
+ """
- def put(self, item, priority):
- event = DataEvent()
+ def __init__(self, name, db: int = 12):
+ self.name = name
+ self.redis = RedisCustom(name, db=db)
+
+ def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
+ # TODO: remove this when we're sure nothing strange is happening
+ assert item is not None
+ assert priority is not None
+ assert selected_model is not None
# Check if the IP is already in the dictionary and if it has reached the limit
- ip_count = self.redis.hget('queued_ip_count', item[1])
- if ip_count:
- ip_count = int(ip_count)
- if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
- print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
+ ip_count = self.get_ip_request_count(item[1])
+ _, simultaneous_ip = get_token_ratelimit(item[2])
+ if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
+            print(f'Rejecting request from {item[1]} - {ip_count} requests already queued.')
return None # reject the request
- self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
- self.increment_ip_count(item[1], 'queued_ip_count')
+ timestamp = time.time()
+ event = DataEvent()
+ self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
return event
def get(self):
@@ -45,31 +54,59 @@ class RedisPriorityQueue:
data = self.redis.zpopmin('queue')
if data:
item = json.loads(data[0][0])
- client_ip = item[0][1]
- self.decrement_ip_count(client_ip, 'queued_ip_count')
return item
time.sleep(0.1) # wait for something to be added to the queue
- def increment_ip_count(self, client_ip: str, redis_key):
- self.redis.hincrby(redis_key, client_ip, 1)
-
- def decrement_ip_count(self, client_ip: str, redis_key):
- new_count = self.redis.hincrby(redis_key, client_ip, -1)
- if new_count <= 0:
- self.redis.hdel(redis_key, client_ip)
-
def __len__(self):
return self.redis.zcard('queue')
- def get_queued_ip_count(self, client_ip: str):
- q = self.redis.hget('queued_ip_count', client_ip)
- if not q:
- return 0
- return 0
+ def get_ip_request_count(self, client_ip: str):
+ """
+ Get the number of requests in the queue from a specific IP.
+        This is a bit inefficient since we iterate over the entire queue, but it
+        keeps the queue as the single source of truth instead of tracking a separate
+        hash, which can get confusing.
+        If we run into slowdowns in the future, we should go back to the separate-hash approach.
+ :param client_ip:
+ :return:
+ """
+ start_time = time.time()
+ items = self.redis.zrange('queue', 0, -1)
+ count = 0
+ for item in items:
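+            # Each queue entry is serialized as (request, event_id, selected_model, timestamp, do_stream); the client IP is request[1].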
+ item_data = json.loads(item)
+ if item_data[0][1] == client_ip:
+ count += 1
+ elapsed_time = time.time() - start_time
+ if elapsed_time > 0.5:
+ raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
+ return count
+
+ def flush(self):
+ self.redis.flush()
+
+ def items(self):
+ return self.redis.zrange('queue', 0, -1)
+
+ def cleanup(self):
+ now = time.time()
+ for item in self.items():
+ item_data = json.loads(item)
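+            # The timestamp is second from the end of the serialized entry (see put()); drop anything older than the request timeout.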
+ timestamp = item_data[-2]
+ if now - timestamp > opts.backend_generate_request_timeout:
+ self.redis.zrem('queue', 0, item)
+ event_id = item_data[1]
+ event = DataEvent(event_id)
+ event.set((False, None, 'closed'))
+ print('Removed timed-out item from queue:', event_id)
class DataEvent:
- def __init__(self, event_id=None):
+ """
+    Class to simplify pub/sub communication between producers and consumers.
+ """
+
+ def __init__(self, event_id: str = None):
self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub()
@@ -84,15 +121,89 @@ class DataEvent:
return pickle.loads(item['data'])
-priority_queue = RedisPriorityQueue()
+def update_active_workers(key: str, operation: str):
+ if operation == 'incr':
+ redis.incr(f'active_gen_workers:{key}')
+ elif operation == 'decr':
+ redis.decr(f'active_gen_workers:{key}')
+ if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
+ redis.set(f'active_gen_workers:{key}', 0)
-def incr_active_workers():
- redis.incr('active_gen_workers')
+def incr_active_workers(selected_model: str, backend_url: str):
+ update_active_workers(selected_model, 'incr')
+ update_active_workers(backend_url, 'incr')
-def decr_active_workers():
- redis.decr('active_gen_workers')
- new_count = redis.get('active_gen_workers', int, 0)
- if new_count < 0:
- redis.set('active_gen_workers', 0)
+def decr_active_workers(selected_model: str, backend_url: str):
+ update_active_workers(selected_model, 'decr')
+ update_active_workers(backend_url, 'decr')
+
+
+class PriorityQueue:
+ """
+    Helper class to wrangle all the different queues.
+ """
+
+ def __init__(self, backends: set = None):
+ """
+        The backends only need to be registered once.
+        :param backends: set of backend URLs to register.
+ """
+ self.redis = Redis(host='localhost', port=6379, db=9)
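+        # db 9 only tracks the set of registered backend URLs; each backend gets its own RedisPriorityQueue.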
+ if backends:
+ for item in backends:
+ self.redis.sadd('backends', item)
+
+ def get_backends(self):
+ return {x.decode('utf-8') for x in self.redis.smembers('backends')}
+
+ def get_queued_ip_count(self, client_ip: str):
+ count = 0
+ for backend_url in self.get_backends():
+ queue = RedisPriorityQueue(backend_url)
+ count += queue.get_ip_request_count(client_ip)
+ return count
+
+ def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
+ queue = RedisPriorityQueue(backend_url)
+ return queue.put(item, priority, selected_model, do_stream)
+
+ def activity(self):
+ lines = []
+ status_redis = RedisCustom('worker_status')
+ for worker in status_redis.keys():
+ lines.append((worker, status_redis.getp(worker)))
+ return sorted(lines)
+
+ def len(self, model_name):
+ count = 0
+ backends_with_models = set()
+ for k in self.get_backends():
+ info = cluster_config.get_backend(k)
+ if info.get('model') == model_name:
+ backends_with_models.add(k)
+ for backend_url in backends_with_models:
+ count += len(RedisPriorityQueue(backend_url))
+ return count
+
+ def __len__(self):
+ count = 0
+ p = set()
+ for backend_url in self.get_backends():
+ queue = RedisPriorityQueue(backend_url)
+ p.add((backend_url, len(queue)))
+ count += len(queue)
+ return count
+
+ def flush(self):
+ for k in self.redis.keys():
+ q = json.loads(self.redis.get(k))
+ q.flush()
+ self.redis.set(k, json.dumps(q))
+
+ def flush_db(self):
+ self.redis.flushdb()
+
+
+priority_queue = PriorityQueue()
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 4b1f640..f4abfa6 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -5,23 +5,22 @@ import flask
from flask import Response, request
from llm_server import opts
-from llm_server.database.conn import database
-from llm_server.database.database import log_prompt
+from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
+from llm_server.custom_redis import redis
+from llm_server.database.database import get_token_ratelimit
+from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
-from llm_server.routes.cache import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
-DEFAULT_PRIORITY = 9999
-
class RequestHandler:
- def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
+ def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request
- self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
+ # self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# Routes need to validate it, here we just load it
if incoming_json:
@@ -34,11 +33,38 @@ class RequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
- self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
- self.backend = get_backend()
+ self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.parameters = None
self.used = False
- redis.zadd('recent_prompters', {self.client_ip: time.time()})
+
+        # This is None by default since most handlers need to transform the prompt in a specific way.
+ self.prompt = None
+
+ self.selected_model = selected_model
+ self.backend_url = get_a_cluster_backend(selected_model)
+ self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
+
+ # Debug stuff
+ # if not self.cluster_backend_info.get('mode'):
+ # print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
+ # if not self.cluster_backend_info.get('model'):
+ # print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
+ # if not self.cluster_backend_info.get('model_config'):
+ # print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
+
+ if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
+ self.offline = True
+ else:
+ self.offline = False
+ self.selected_model = self.cluster_backend_info['model']
+ self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
+ if self.token and not self.token.startswith('SYSTEM__'):
+ # "recent_prompters" is only used for stats.
+ redis.zadd('recent_prompters', {self.client_ip: time.time()})
+
+ def check_online(self) -> bool:
+ self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
+ return self.cluster_backend_info['online']
def get_auth_token(self):
if self.request_json_body.get('X-API-KEY'):
@@ -49,6 +75,8 @@ class RequestHandler:
return parse_token(self.request.headers['Authorization'])
def get_client_ip(self):
+ if self.request.headers.get('Llm-Connecting-Ip'):
+ return self.request.headers['Llm-Connecting-Ip']
if self.request.headers.get('X-Connecting-IP'):
return self.request.headers.get('X-Connecting-IP')
elif self.request.headers.get('Cf-Connecting-Ip'):
@@ -58,26 +86,7 @@ class RequestHandler:
else:
return self.request.remote_addr
- def get_token_ratelimit(self):
- priority = DEFAULT_PRIORITY
- simultaneous_ip = opts.simultaneous_requests_per_ip
- if self.token:
- cursor = database.cursor()
- try:
- cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
- result = cursor.fetchone()
- if result:
- priority, simultaneous_ip = result
- if simultaneous_ip is None:
- # No ratelimit for this token if null
- simultaneous_ip = 999999999
- finally:
- cursor.close()
- return priority, simultaneous_ip
-
def get_parameters(self):
- if self.request_json_body.get('max_tokens'):
- self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg
@@ -119,7 +128,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
- log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
+ log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
return False, backend_response
return True, (None, 0)
@@ -131,14 +140,18 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid:
return (False, None, None, 0), invalid_response
- event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
+ event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
else:
event = None
if not event:
return (False, None, None, 0), self.handle_ratelimited()
+ # TODO: add wait timeout
success, response, error_msg = event.wait()
+ if error_msg == 'closed':
+ return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
+
end_time = time.time()
elapsed_time = end_time - self.start_time
@@ -160,7 +173,17 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
- log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
+ log_to_db(ip=self.client_ip,
+ token=self.token,
+ prompt=prompt,
+ response=backend_response[0].data.decode('utf-8'),
+ gen_time=None,
+ parameters=self.parameters,
+ headers=dict(self.request.headers),
+ backend_response_code=response_status_code,
+ request_url=self.request.url,
+ backend_url=self.backend_url,
+ is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@@ -180,7 +203,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
- log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
+ log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@@ -189,22 +212,29 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool:
+ if self.token_priority == 0:
+ return False
+
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
x = redis.hget('processing_ips', self.client_ip)
if x:
processing_ip = int(x)
else:
processing_ip = 0
- if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
- return False
- else:
- print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
+
+ if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
+ print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
return True
+ else:
+ return False
def handle_request(self) -> Tuple[flask.Response, int]:
# Must include this in your child.
- # if self.used:
- # raise Exception('Can only use a RequestHandler object once.')
+ # assert not self.used
+ # if self.offline:
+ # msg = f'{self.selected_model} is not a valid model choice.'
+ # print(msg)
+ # return format_sillytavern_err(msg)
raise NotImplementedError
def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
@@ -214,11 +244,11 @@ class RequestHandler:
raise NotImplementedError
-def get_backend():
- if opts.mode == 'oobabooga':
- return OobaboogaBackend()
- elif opts.mode == 'vllm':
- return VLLMBackend()
+def get_backend_handler(mode, backend_url: str):
+ if mode == 'oobabooga':
+ return OobaboogaBackend(backend_url)
+ elif mode == 'vllm':
+ return VLLMBackend(backend_url)
else:
raise Exception
diff --git a/llm_server/routes/server_error.py b/llm_server/routes/server_error.py
index fec3836..a6d6f99 100644
--- a/llm_server/routes/server_error.py
+++ b/llm_server/routes/server_error.py
@@ -1,3 +1,3 @@
def handle_server_error(e):
- print(e)
- return {'error': True}, 500
+ print('Internal Error:', e)
+ return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500
diff --git a/llm_server/routes/stats.py b/llm_server/routes/stats.py
index a6e9e17..7f3b2fe 100644
--- a/llm_server/routes/stats.py
+++ b/llm_server/routes/stats.py
@@ -1,33 +1,11 @@
from datetime import datetime
-from llm_server.routes.cache import redis
-
-# proompters_5_min = 0
-# concurrent_semaphore = Semaphore(concurrent_gens)
+from llm_server.custom_redis import redis
+from llm_server.helpers import round_up_base
server_start_time = datetime.now()
-# TODO: do I need this?
-# def elapsed_times_cleanup():
-# global wait_in_queue_elapsed
-# while True:
-# current_time = time.time()
-# with wait_in_queue_elapsed_lock:
-# global wait_in_queue_elapsed
-# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
-# time.sleep(1)
-
-
-def calculate_avg_gen_time():
- # Get the average generation time from Redis
- average_generation_time = redis.get('average_generation_time')
- if average_generation_time is None:
- return 0
- else:
- return float(average_generation_time)
-
-
def get_total_proompts():
count = redis.get('proompts')
if count is None:
@@ -37,10 +15,27 @@ def get_total_proompts():
return count
-def get_active_gen_workers():
- active_gen_workers = redis.get('active_gen_workers')
- if active_gen_workers is None:
- count = 0
+def get_active_gen_workers_model(selected_model: str = None):
+ return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0)
+
+
+def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
+ if active_gen_workers < concurrent_gens:
+ return 0
+ elif active_gen_workers >= concurrent_gens:
+ # Calculate how long it will take to complete the currently running gens and the queued requests.
+ # If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
+ # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
+ # that number up to the nearest multiple of gen_time_calc (i.e. if gen_time_calc is 8 and the calculated number is 11.6, we will get 16). Finally,
+ # add gen_time_calc to the time to account for the currently running generations.
+ # This assumes that all active workers will finish at the same time, which is unlikely.
+ # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
+ proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
+ else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
+ return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
+ elif proompters_in_queue == 0 and active_gen_workers == 0:
+ # No queue, no workers
+ return 0
else:
- count = int(active_gen_workers)
- return count
+ # No queue
+ return gen_time_calc
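
A quick worked example of the estimate above, under the assumption that round_up_base() rounds a value up to the nearest multiple of its base (which is how it is used here); the numbers are made up:

    import math

    def round_up_base(x, base):
        # Assumed behavior of llm_server.helpers.round_up_base.
        return math.ceil(x / base) * base

    # Hypothetical: 10 s average generation time, 12 queued prompts,
    # 4 concurrent slots, all 4 workers busy.
    gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers = 10, 12, 4, 4
    queue_wait = round_up_base((proompters_in_queue / concurrent_gens) * gen_time_calc, base=gen_time_calc)  # 30
    estimated_wait_sec = queue_wait + gen_time_calc  # 40 seconds reported to clients
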
diff --git a/llm_server/routes/v1/generate.py b/llm_server/routes/v1/generate.py
index 715288f..fcdc298 100644
--- a/llm_server/routes/v1/generate.py
+++ b/llm_server/routes/v1/generate.py
@@ -3,18 +3,18 @@ import traceback
from flask import jsonify, request
from . import bp
-from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
-@bp.route('/generate', methods=['POST'])
-def generate():
+@bp.route('/v1/generate', methods=['POST'])
+@bp.route('/<model_name>/v1/generate', methods=['POST'])
+def generate(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
- handler = OobaRequestHandler(request)
+ handler = OobaRequestHandler(request, selected_model=model_name)
try:
return handler.handle_request()
except Exception:
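
For reference, a minimal client call against the blocking endpoint registered above might look like the sketch below. The host, token, and mount point are placeholders (the blueprint is assumed to be served under /api, as in the nginx config later in this diff), and the response shape follows the Oobabooga-style results[0].text format used by the handlers:

    import requests

    url = 'https://proxy.example.com/api/v1/generate'  # placeholder host
    payload = {
        'prompt': 'Once upon a time',
        'max_new_tokens': 100,
        'X-API-KEY': 'your-token-here',  # optional; see get_auth_token() in the request handler
    }
    r = requests.post(url, json=payload, timeout=120)
    print(r.json()['results'][0]['text'])
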
diff --git a/llm_server/routes/v1/generate_stats.py b/llm_server/routes/v1/generate_stats.py
index e144099..a9148b3 100644
--- a/llm_server/routes/v1/generate_stats.py
+++ b/llm_server/routes/v1/generate_stats.py
@@ -2,83 +2,30 @@ import time
from datetime import datetime
from llm_server import opts
+from llm_server.cluster.backend import get_model_choices
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column
-from llm_server.helpers import deep_sort, round_up_base
-from llm_server.llm.info import get_running_model
-from llm_server.netdata import get_power_states
-from llm_server.routes.cache import redis
-from llm_server.routes.queue import priority_queue
-from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
+from llm_server.helpers import deep_sort
+from llm_server.routes.stats import get_total_proompts, server_start_time
-def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
- if active_gen_workers < concurrent_gens:
- return 0
- elif active_gen_workers >= concurrent_gens:
- # Calculate how long it will take to complete the currently running gens and the queued requests.
- # If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
- # Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
- # that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
- # Add gen_time_calc to the time to account for the currently running generations.
- # This assumes that all active workers will finish at the same time, which is unlikely.
- # Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
- proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
- else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
- return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
- elif proompters_in_queue == 0 and active_gen_workers == 0:
- # No queue, no workers
- return 0
- else:
- # No queue
- return gen_time_calc
-
-
-# TODO: have routes/__init__.py point to the latest API version generate_stats()
-
def generate_stats(regen: bool = False):
if not regen:
- c = redis.get('proxy_stats', dict)
+ c = redis.getp('proxy_stats')
if c:
return c
- model_name, error = get_running_model() # will return False when the fetch fails
- if isinstance(model_name, bool):
- online = False
- else:
- online = True
- redis.set('running_model', model_name)
+ model_choices, default_model = get_model_choices(regen=True)
- # t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
- # if len(t) == 0:
- # estimated_wait = 0
- # else:
- # waits = [elapsed for end, elapsed in t]
- # estimated_wait = int(sum(waits) / len(waits))
-
- active_gen_workers = get_active_gen_workers()
- proompters_in_queue = len(priority_queue)
-
- # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
- # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
-
- average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
- estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
-
- if opts.netdata_root:
- netdata_stats = {}
- power_states = get_power_states()
- for gpu, power_state in power_states.items():
- netdata_stats[gpu] = {
- 'power_state': power_state,
- # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
- }
- else:
- netdata_stats = {}
-
- base_client_api = redis.get('base_client_api', str)
+ base_client_api = redis.get('base_client_api', dtype=str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
output = {
+ 'models': {
+ 'choices': model_choices,
+ 'default': default_model,
+ },
'stats': {
'proompters': {
'5_min': proompters_5_min,
@@ -86,39 +33,49 @@ def generate_stats(regen: bool = False):
},
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
- 'average_generation_elapsed_sec': int(average_generation_time),
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
+ 'num_backends': len(cluster_config.all()) if opts.show_backends else None,
},
- 'online': online,
'endpoints': {
'blocking': f'https://{base_client_api}',
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
},
- 'queue': {
- 'processing': active_gen_workers,
- 'queued': proompters_in_queue,
- 'estimated_wait_sec': int(estimated_wait_sec),
- },
'timestamp': int(time.time()),
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
- 'context_size': opts.context_size,
- 'concurrent': opts.concurrent_gens,
- 'model': opts.manual_model_name if opts.manual_model_name else model_name,
- 'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
+ 'api_mode': opts.frontend_api_mode
},
'keys': {
'openaiKeys': '∞',
'anthropicKeys': '∞',
},
- 'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
- 'nvidia': netdata_stats
+ 'backends': {},
+ 'online': len(model_choices) > 0
}
+
+ # TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourselves
+
+ if opts.show_backends:
+ for backend_url, v in cluster_config.all().items():
+ backend_info = cluster_config.get_backend(backend_url)
+ if not backend_info['online']:
+ continue
+ backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
+ output['backends'][backend_info['hash']] = {
+ 'uptime': backend_uptime,
+ 'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),
+ 'model': backend_info['model'],
+ 'mode': backend_info['mode'],
+ 'nvidia': backend_info['nvidia'],
+ 'priority': backend_info['priority'],
+ }
+
result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then.
if base_client_api:
- redis.set_dict('proxy_stats', result) # Cache with no expiry
+ redis.setp('proxy_stats', result)
+
return result
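
Consumers of the cached stats (e.g. the stats route in proxy.py) receive the dictionary assembled above. A small sketch of reading it client-side, using the field names from the output dict (host is a placeholder):

    import requests

    stats = requests.get('https://proxy.example.com/api/v1/stats', timeout=10).json()
    if stats['online']:
        print('Default model:', stats['models']['default'])
        for backend_hash, info in (stats.get('backends') or {}).items():
            print(backend_hash, info['model'], 'max_tokens:', info['max_tokens'])
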
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 0fc8f40..3ed2f58 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -1,186 +1,200 @@
import json
-import threading
import time
import traceback
-from typing import Union
+import ujson
from flask import request
+from redis import Redis
-from ..cache import redis
+from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
-from ..queue import decr_active_workers, decrement_ip_count, priority_queue
+from ..queue import priority_queue
from ... import opts
-from ...database.database import log_prompt
-from ...llm.generator import generator
-from ...llm.vllm import tokenize
-from ...stream import sock
+from ...custom_redis import redis
+from ...database.log_to_db import log_to_db
+from ...sock import sock
-# TODO: have workers process streaming requests
-# TODO: make sure to log the token as well (seems to be missing in the DB right now)
+# Stacking the @sock.route() decorators creates a TypeError on the /v1/stream endpoint.
+# We solve this by splitting the routes.
-@sock.route('/api/v1/stream')
-def stream(ws):
- def send_err_and_quit(quitting_err_msg):
- ws.send(json.dumps({
- 'event': 'text_stream',
- 'message_num': 0,
- 'text': quitting_err_msg
- }))
- ws.send(json.dumps({
- 'event': 'stream_end',
- 'message_num': 1
- }))
- ws.close()
- log_in_bg(quitting_err_msg, is_error=True)
+@bp.route('/v1/stream')
+@bp.route('/<model_name>/v1/stream')
+def stream(model_name=None):
+ return 'This is a websocket endpoint.', 400
- def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
- def background_task_exception():
- generated_tokens = tokenize(generated_text_bg)
- log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
+@sock.route('/v1/stream', bp=bp)
+def stream_without_model(ws):
+ do_stream(ws, model_name=None)
- # TODO: use async/await instead of threads
- thread = threading.Thread(target=background_task_exception)
- thread.start()
- thread.join()
- if not opts.enable_streaming:
- return 'Streaming is disabled', 401
+@sock.route('/<model_name>/v1/stream', bp=bp)
+def stream_with_model(ws, model_name=None):
+ do_stream(ws, model_name)
- r_headers = dict(request.headers)
- r_url = request.url
- message_num = 0
- while ws.connected:
- message = ws.receive()
- request_valid_json, request_json_body = validate_json(message)
- if not request_valid_json or not request_json_body.get('prompt'):
- return 'Invalid JSON', 400
- else:
- if opts.mode != 'vllm':
- # TODO: implement other backends
- raise NotImplementedError
- auth_failure = require_api_key(request_json_body)
- if auth_failure:
- return auth_failure
+def do_stream(ws, model_name):
+ event_id = None
+ try:
+ def send_err_and_quit(quitting_err_msg):
+ ws.send(json.dumps({
+ 'event': 'text_stream',
+ 'message_num': 0,
+ 'text': quitting_err_msg
+ }))
+ ws.send(json.dumps({
+ 'event': 'stream_end',
+ 'message_num': 1
+ }))
+ ws.close()
+ log_to_db(ip=handler.client_ip,
+ token=handler.token,
+ prompt=input_prompt,
+ response=quitting_err_msg,
+ gen_time=None,
+ parameters=handler.parameters,
+ headers=r_headers,
+ backend_response_code=response_status_code,
+ request_url=r_url,
+ backend_url=handler.backend_url,
+ response_tokens=None,
+ is_error=True
+ )
- handler = OobaRequestHandler(request, request_json_body)
- generated_text = ''
- input_prompt = request_json_body['prompt']
- response_status_code = 0
- start_time = time.time()
+ if not opts.enable_streaming:
+ return 'Streaming disabled', 403
- err_msg = None
- if handler.is_client_ratelimited():
- r, _ = handler.handle_ratelimited(do_log=False)
- err_msg = r.json['results'][0]['text']
+ r_headers = dict(request.headers)
+ r_url = request.url
+ message_num = 0
+
+ while ws.connected:
+ message = ws.receive()
+ request_valid_json, request_json_body = validate_json(message)
+
+ if not request_valid_json or not request_json_body.get('prompt'):
+ return 'Invalid JSON', 400
else:
- request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
- if not request_valid:
- err_msg = invalid_response[0].json['results'][0]['text']
- if err_msg:
- send_err_and_quit(err_msg)
- return
+ # We have to do auth ourselves since the details are sent in the message.
+ auth_failure = require_api_key(request_json_body)
+ if auth_failure:
+ return auth_failure
- llm_request = {
- **handler.parameters,
- 'prompt': input_prompt,
- 'stream': True,
- }
-
- # Add a dummy event to the queue and wait for it to reach a worker
- event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
- if not event:
- r, _ = handler.handle_ratelimited()
- err_msg = r.json['results'][0]['text']
- send_err_and_quit(err_msg)
- return
- try:
- response = generator(llm_request)
- if not response:
- error_msg = 'Failed to reach backend while streaming.'
- print('Streaming failed:', error_msg)
- msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
+ handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
+ if handler.offline:
+ msg = f'{handler.selected_model} is not a valid model choice.'
+ print(msg)
ws.send(json.dumps({
'event': 'text_stream',
- 'message_num': message_num,
+ 'message_num': 0,
'text': msg
}))
+ return
+
+ if handler.cluster_backend_info['mode'] != 'vllm':
+ # TODO: implement other backends
+ raise NotImplementedError
+
+ input_prompt = request_json_body['prompt']
+ response_status_code = 0
+ start_time = time.time()
+
+ err_msg = None
+ if handler.is_client_ratelimited():
+ r, _ = handler.handle_ratelimited(do_log=False)
+ err_msg = r.json['results'][0]['text']
else:
- # Be extra careful when getting attributes from the response object
- try:
- response_status_code = response.status_code
- except:
- response_status_code = 0
+ request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
+ if not request_valid:
+ err_msg = invalid_response[0].json['results'][0]['text']
+ if err_msg:
+ send_err_and_quit(err_msg)
+ return
- partial_response = b''
+ handler.parameters, _ = handler.get_parameters()
+ handler.prompt = input_prompt
+ handler.request_json_body = {
+ 'prompt': handler.prompt,
+ **handler.parameters
+ }
- for chunk in response.iter_content(chunk_size=1):
- partial_response += chunk
- if partial_response.endswith(b'\x00'):
- json_strs = partial_response.split(b'\x00')
- for json_str in json_strs:
- if json_str:
- try:
- json_obj = json.loads(json_str.decode())
- new = json_obj['text'][0].split(input_prompt + generated_text)[1]
- generated_text = generated_text + new
- except IndexError:
- # ????
- continue
- try:
- ws.send(json.dumps({
- 'event': 'text_stream',
- 'message_num': message_num,
- 'text': new
- }))
- except:
- # The has client closed the stream.
- if request:
- request.close()
- ws.close()
- end_time = time.time()
- elapsed_time = end_time - start_time
- log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
- return
+ event = None
+ if not handler.is_client_ratelimited():
+ event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
+ if not event:
+ r = handler.handle_ratelimited()
+ send_err_and_quit(r[0].data)
+ return
+ event_id = event.event_id
+ _, stream_name, error_msg = event.wait()
+ if error_msg:
+ print('Stream failed to start streaming:', error_msg)
+ ws.close(reason=1014, message='Request Timeout')
+ return
+
+ stream_redis = Redis(db=8)
+ generated_text = ''
+
+ try:
+ last_id = '0-0' # The ID of the last entry we read.
+ while True:
+ stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
+ if not stream_data:
+ print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
+ return
+ else:
+ for stream_index, item in stream_data[0][1]:
+ last_id = stream_index
+ data = ujson.loads(item[b'data'])
+ if data['error']:
+ print(data['error'])
+ send_err_and_quit('Encountered exception while streaming.')
+ return
+ elif data['new']:
+ ws.send(json.dumps({
+ 'event': 'text_stream',
+ 'message_num': message_num,
+ 'text': data['new']
+ }))
message_num += 1
- partial_response = b'' # Reset the partial response
-
- # If there is no more data, break the loop
- if not chunk:
- break
-
- end_time = time.time()
- elapsed_time = end_time - start_time
- log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
- except:
- traceback.print_exc()
- generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
- ws.send(json.dumps({
- 'event': 'text_stream',
- 'message_num': message_num,
- 'text': generated_text
- }))
- if request:
- request.close()
- ws.close()
- log_in_bg(generated_text, is_error=True, status_code=response_status_code)
- return
- finally:
- # The worker incremented it, we'll decrement it.
- decrement_ip_count(handler.client_ip, 'processing_ips')
- decr_active_workers()
- try:
- ws.send(json.dumps({
- 'event': 'stream_end',
- 'message_num': message_num
- }))
- except:
- # The client closed the stream.
- end_time = time.time()
- elapsed_time = end_time - start_time
- log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
- ws.close() # this is important if we encountered and error and exited early.
+ generated_text = generated_text + data['new']
+ elif data['completed']:
+ return
+ except:
+ send_err_and_quit('Encountered exception while streaming.')
+ traceback.print_exc()
+ finally:
+ try:
+ ws.send(json.dumps({
+ 'event': 'stream_end',
+ 'message_num': message_num
+ }))
+ except:
+ # The client closed the stream.
+ pass
+ if stream_name:
+ stream_redis.delete(stream_name)
+ end_time = time.time()
+ elapsed_time = end_time - start_time
+ log_to_db(ip=handler.client_ip,
+ token=handler.token,
+ prompt=input_prompt,
+ response=generated_text,
+ gen_time=elapsed_time,
+ parameters=handler.parameters,
+ headers=r_headers,
+ backend_response_code=response_status_code,
+ request_url=r_url,
+ backend_url=handler.backend_url
+ )
+ finally:
+ if event_id:
+ redis.publish(f'notifications:{event_id}', 'canceled')
+ try:
+ # Must close the connection or greenlets will complain.
+ ws.close()
+ except:
+ pass
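
From the client's point of view, the streaming flow above reduces to: open a websocket, send a single JSON message with the prompt and parameters, then read text_stream events until stream_end. A minimal sketch using the websocket-client package (host and mount point are placeholders; auth details go in the message body, since require_api_key() reads them from there):

    import json
    import websocket  # pip install websocket-client

    ws = websocket.create_connection('wss://proxy.example.com/api/v1/stream')
    ws.send(json.dumps({'prompt': 'Once upon a time', 'max_new_tokens': 100}))
    while True:
        msg = json.loads(ws.recv())
        if msg['event'] == 'text_stream':
            print(msg['text'], end='', flush=True)
        elif msg['event'] == 'stream_end':
            break
    ws.close()
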
diff --git a/llm_server/routes/v1/info.py b/llm_server/routes/v1/info.py
index 2091118..342921e 100644
--- a/llm_server/routes/v1/info.py
+++ b/llm_server/routes/v1/info.py
@@ -2,22 +2,16 @@ import time
from flask import jsonify, request
+from llm_server.custom_redis import flask_cache
from . import bp
-from ..auth import requires_auth
-from ..cache import flask_cache
from ... import opts
-from ...llm.info import get_running_model
+from ...cluster.backend import get_backends_from_model, is_valid_model
+from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
-# @bp.route('/info', methods=['GET'])
-# # @cache.cached(timeout=3600, query_string=True)
-# def get_info():
-# # requests.get()
-# return 'yes'
-
-
-@bp.route('/model', methods=['GET'])
-def get_model():
+@bp.route('/v1/model', methods=['GET'])
+@bp.route('/<model_name>/v1/model', methods=['GET'])
+def get_model(model_name=None):
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@@ -26,24 +20,21 @@ def get_model():
if cached_response:
return cached_response
- model_name, error = get_running_model()
if not model_name:
+ model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
+
+ if not is_valid_model(model_name):
response = jsonify({
- 'code': 502,
- 'msg': 'failed to reach backend',
- 'type': error.__class__.__name__
- }), 500 # return 500 so Cloudflare doesn't intercept us
+ 'code': 400,
+ 'msg': 'Model does not exist.',
+ }), 400
else:
+ num_backends = len(get_backends_from_model(model_name))
response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name,
+ 'model_backend_count': num_backends,
'timestamp': int(time.time())
}), 200
flask_cache.set(cache_key, response, timeout=60)
return response
-
-
-@bp.route('/backend', methods=['GET'])
-@requires_auth
-def get_backend():
- return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
diff --git a/llm_server/routes/v1/proxy.py b/llm_server/routes/v1/proxy.py
index 4349526..6e3708e 100644
--- a/llm_server/routes/v1/proxy.py
+++ b/llm_server/routes/v1/proxy.py
@@ -1,8 +1,10 @@
from flask import jsonify
+from llm_server.custom_redis import flask_cache
from . import bp
from .generate_stats import generate_stats
-from ..cache import flask_cache
+from ..auth import requires_auth
+from ...cluster.cluster_config import cluster_config, get_backends
from ...helpers import jsonify_pretty
@@ -10,3 +12,14 @@ from ...helpers import jsonify_pretty
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
return jsonify_pretty(generate_stats())
+
+
+@bp.route('/backends', methods=['GET'])
+@requires_auth
+def get_backend():
+ online, offline = get_backends()
+ result = {}
+ for i in online + offline:
+ info = cluster_config.get_backend(i)
+ result[info['hash']] = info
+ return jsonify(result), 200
diff --git a/llm_server/stream.py b/llm_server/sock.py
similarity index 77%
rename from llm_server/stream.py
rename to llm_server/sock.py
index 8ac2fc1..2f1a17d 100644
--- a/llm_server/stream.py
+++ b/llm_server/sock.py
@@ -3,6 +3,6 @@ from flask_sock import Sock
sock = Sock()
-def init_socketio(app):
+def init_wssocket(app):
global sock
sock.init_app(app)
diff --git a/llm_server/workers/app.py b/llm_server/workers/app.py
deleted file mode 100644
index fda6fb3..0000000
--- a/llm_server/workers/app.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from threading import Thread
-
-from .blocking import start_workers
-from .main import main_background_thread
-from .moderator import start_moderation_workers
-from .printer import console_printer
-from .recent import recent_prompters_thread
-from .threads import cache_stats
-from .. import opts
-
-
-def start_background():
- start_workers(opts.concurrent_gens)
-
- t = Thread(target=main_background_thread)
- t.daemon = True
- t.start()
- print('Started the main background thread.')
-
- start_moderation_workers(opts.openai_moderation_workers)
-
- t = Thread(target=cache_stats)
- t.daemon = True
- t.start()
- print('Started the stats cacher.')
-
- t = Thread(target=recent_prompters_thread)
- t.daemon = True
- t.start()
- print('Started the recent proompters thread.')
-
- t = Thread(target=console_printer)
- t.daemon = True
- t.start()
- print('Started the console printer.')
diff --git a/llm_server/workers/blocking.py b/llm_server/workers/blocking.py
deleted file mode 100644
index 27b0815..0000000
--- a/llm_server/workers/blocking.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import threading
-import time
-
-from llm_server import opts
-from llm_server.llm.generator import generator
-from llm_server.routes.cache import redis
-from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
-
-
-def worker():
- while True:
- need_to_wait()
- (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
- need_to_wait()
-
- increment_ip_count(client_ip, 'processing_ips')
- incr_active_workers()
-
- if not request_json_body:
- # This was a dummy request from the websocket handler.
- # We're going to let the websocket handler decrement processing_ips and active_gen_workers.
- continue
-
- try:
- success, response, error_msg = generator(request_json_body)
- event = DataEvent(event_id)
- event.set((success, response, error_msg))
- finally:
- decrement_ip_count(client_ip, 'processing_ips')
- decr_active_workers()
-
-
-def start_workers(num_workers: int):
- i = 0
- for _ in range(num_workers):
- t = threading.Thread(target=worker)
- t.daemon = True
- t.start()
- i += 1
- print(f'Started {i} inference workers.')
-
-
-def need_to_wait():
- # We need to check the number of active workers since the streaming endpoint may be doing something.
- active_workers = redis.get('active_gen_workers', int, 0)
- s = time.time()
- while active_workers >= opts.concurrent_gens:
- time.sleep(0.01)
- e = time.time()
- if e - s > 0.5:
- print(f'Worker was delayed {e - s} seconds.')
diff --git a/llm_server/workers/cleaner.py b/llm_server/workers/cleaner.py
new file mode 100644
index 0000000..95a6a78
--- /dev/null
+++ b/llm_server/workers/cleaner.py
@@ -0,0 +1,32 @@
+import time
+
+from redis import Redis
+
+from llm_server.workers.inferencer import STREAM_NAME_PREFIX
+
+
+# NOT NEEDED
+
+def cleaner():
+ r = Redis(db=8)
+ stream_info = {}
+
+ while True:
+ all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
+ processed_streams = []
+ for stream in all_streams:
+ stream = stream.decode()
+ current_size = r.xlen(stream)
+
+ # If the stream is new or its size has changed, update the size and time in the dictionary
+ if stream not in stream_info or current_size != stream_info[stream]['size']:
+ stream_info[stream] = {'size': current_size, 'time': time.time()}
+ processed_streams.append(stream)
+ else:
+ # If the size hasn't changed for 5 minutes, delete the stream
+ if time.time() - stream_info[stream]['time'] >= 300:
+ r.delete(stream)
+ print(f"Stream '{stream}' deleted due to inactivity.")
+ del stream_info[stream]
+
+ time.sleep(60)
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
new file mode 100644
index 0000000..21e45d0
--- /dev/null
+++ b/llm_server/workers/inferencer.py
@@ -0,0 +1,160 @@
+import json
+import threading
+import time
+import traceback
+from uuid import uuid4
+
+import ujson
+from redis import Redis
+
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import RedisCustom, redis
+from llm_server.llm.generator import generator
+from llm_server.logging import create_logger
+from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
+
+stream_redis = Redis(db=8)
+
+STREAM_NAME_PREFIX = 'stream'
+
+
+def check_cancellation(event, event_id):
+ """
+ This thread checks the pub/sub channel in the background so the main process
+ isn't bogged down with Redis calls. Otherwise, the main process slows down to 1 token/sec.
+ :param event:
+ :param event_id:
+ :return:
+ """
+ pubsub = redis.pubsub()
+ pubsub.subscribe(f'notifications:{event_id}')
+ while not event.is_set():
+ message = pubsub.get_message()
+ if message and message['data'] == b'canceled':
+ event.set()
+ time.sleep(0.5) # check every half second
+
+
+def get_stream_name(name: str):
+ return f'{STREAM_NAME_PREFIX}:{name}'
+
+
+def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
+ logger = create_logger('inferencer')
+ prompt = msg_to_backend['prompt']
+ stream_name = get_stream_name(stream_name)
+ stream_redis.delete(stream_name)  # be extra sure (stream_name already has the prefix applied above)
+ event = threading.Event()
+ threading.Thread(target=check_cancellation, args=(event, event_id)).start()
+ try:
+ response = generator(msg_to_backend, backend_url)
+ generated_text = ''
+ partial_response = b''
+ for chunk in response.iter_content(chunk_size=1):
+ # If there is no more data, break the loop
+ if not chunk:
+ break
+ if event.is_set():
+ logger.debug('Client canceled generation')
+ response.close()
+ return
+
+ partial_response += chunk
+ if partial_response.endswith(b'\x00'):
+ json_strs = partial_response.split(b'\x00')
+ for json_str in json_strs:
+ if json_str:
+ try:
+ json_obj = json.loads(json_str.decode())
+ new = json_obj['text'][0].split(prompt + generated_text)[1]
+ generated_text = generated_text + new
+ except IndexError:
+ # ????
+ continue
+ stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
+ except AttributeError as e:
+ if str(e) == "'bool' object has no attribute 'iter_content'":
+ # We don't care about these errors.
+ logger.debug('failed to stream from backend - no response')
+ else:
+ raise
+ except Exception as e:
+ stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
+ raise # We won't handle the exception here.
+ finally:
+ # Publish final message to Redis stream
+ stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
+ event.set() # stop the cancellation checking thread
+
+
+#
+def worker(backend_url):
+ logger = create_logger('inferencer')
+ status_redis = RedisCustom('worker_status')
+ worker_id = str(uuid4())
+ status_redis.setp(str(worker_id), None)
+ redis_queue = RedisPriorityQueue(backend_url)
+ while True:
+ status_redis.setp(str(worker_id), 'waiting...')
+ (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
+ event = DataEvent(event_id)
+
+ try:
+ backend_info = cluster_config.get_backend(backend_url)
+ except:
+ # This is not a critical error because it usually means that the backend is
+ # offline and this backend is in a state of transition from online to offline.
+ logger.debug(f'Got an exception while getting info for backend {backend_url}: {traceback.format_exc()}')
+ event.set((False, None, 'exception'))
+ continue
+
+ if not backend_info['online']:
+ event.set((False, None, 'canceled'))
+ continue
+
+ if not selected_model:
+ selected_model = backend_info['model']
+
+ logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}. Streaming: {do_stream}")
+
+ try:
+ stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
+ increment_ip_count(client_ip, 'processing_ips')
+ incr_active_workers(selected_model, backend_url)
+
+ if do_stream:
+ status_redis.setp(str(worker_id), ('streaming', client_ip))
+
+ # Return the name of the stream that the slave should connect to.
+ event.set((True, get_stream_name(worker_id), None))
+
+ msg_to_backend = {
+ **parameters,
+ 'prompt': request_json_body['prompt'],
+ 'stream': True,
+ }
+ inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
+ else:
+ # Normal inference (not streaming).
+ status_redis.setp(str(worker_id), ('generating', client_ip))
+ success, response, error_msg = generator(request_json_body, backend_url)
+ event.set((success, response, error_msg))
+ except:
+ logger.error(traceback.format_exc())
+ event.set((False, None, 'exception'))
+ finally:
+ decrement_ip_count(client_ip, 'processing_ips')
+ decr_active_workers(selected_model, backend_url)
+ status_redis.setp(str(worker_id), None)
+
+
+def start_workers(cluster: dict):
+ logger = create_logger('inferencer')
+ i = 0
+ for item in cluster:
+ for _ in range(item['concurrent_gens']):
+ t = threading.Thread(target=worker, args=(item['backend_url'],))
+ t.daemon = True
+ t.start()
+ i += 1
+ logger.info(f'Started {i} inference workers.')
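
The worker and the websocket handler communicate through a Redis stream in database 8: the worker xadd's JSON-encoded chunks, and the handler xread's them until an entry with completed=True arrives. Stripped of the surrounding machinery, the pattern looks roughly like this (stream name and chunk text are made up):

    import json
    from redis import Redis

    stream_redis = Redis(db=8)
    stream_name = 'stream:example-worker-id'  # hypothetical

    # Producer side (the inference worker):
    stream_redis.xadd(stream_name, {'data': json.dumps({'new': 'Hello', 'completed': False, 'error': None})})
    stream_redis.xadd(stream_name, {'data': json.dumps({'new': None, 'completed': True, 'error': None})})

    # Consumer side (the websocket handler):
    last_id = '0-0'
    finished = False
    while not finished:
        entries = stream_redis.xread({stream_name: last_id}, block=5000)
        if not entries:
            break  # timed out waiting for the worker
        for entry_id, item in entries[0][1]:
            last_id = entry_id
            data = json.loads(item[b'data'])
            if data['completed']:
                finished = True
                break
            print(data['new'], end='')
    stream_redis.delete(stream_name)
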
diff --git a/llm_server/workers/logger.py b/llm_server/workers/logger.py
new file mode 100644
index 0000000..eada969
--- /dev/null
+++ b/llm_server/workers/logger.py
@@ -0,0 +1,31 @@
+import pickle
+import traceback
+
+import redis
+
+from llm_server.database.database import do_db_log
+
+
+def db_logger():
+ """
+ We don't want the logging operation to be blocking, so we will use a background worker
+ to do the logging.
+ :return:
+ """
+
+ r = redis.Redis(host='localhost', port=6379, db=3)
+ p = r.pubsub()
+ p.subscribe('database-logger')
+
+ for message in p.listen():
+ try:
+ if message['type'] == 'message':
+ data = pickle.loads(message['data'])
+ function_name = data['function']
+ args = data['args']
+ kwargs = data['kwargs']
+
+ if function_name == 'log_prompt':
+ do_db_log(*args, **kwargs)
+ except:
+ traceback.print_exc()
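
The subscriber above implies a publisher that pickles each logging call and pushes it over the database-logger channel. log_to_db() in llm_server/database/log_to_db.py is not shown in this diff, but it presumably does something along these lines (a hedged sketch, not the actual implementation):

    import pickle
    import redis

    r = redis.Redis(host='localhost', port=6379, db=3)

    def log_to_db(*args, **kwargs):
        # Assumed message shape: mirrors what db_logger() unpickles and dispatches to do_db_log().
        r.publish('database-logger', pickle.dumps({
            'function': 'log_prompt',
            'args': args,
            'kwargs': kwargs,
        }))
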
diff --git a/llm_server/workers/main.py b/llm_server/workers/main.py
deleted file mode 100644
index 747f699..0000000
--- a/llm_server/workers/main.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import time
-from threading import Thread
-
-from llm_server import opts
-from llm_server.database.database import weighted_average_column_for_model
-from llm_server.llm.info import get_running_model
-from llm_server.routes.cache import redis
-
-
-def main_background_thread():
- redis.set('average_generation_elapsed_sec', 0)
- redis.set('estimated_avg_tps', 0)
- redis.set('average_output_tokens', 0)
- redis.set('backend_online', 0)
- redis.set_dict('backend_info', {})
-
- while True:
- # TODO: unify this
- if opts.mode == 'oobabooga':
- running_model, err = get_running_model()
- if err:
- print(err)
- redis.set('backend_online', 0)
- else:
- redis.set('running_model', running_model)
- redis.set('backend_online', 1)
- elif opts.mode == 'vllm':
- running_model, err = get_running_model()
- if err:
- print(err)
- redis.set('backend_online', 0)
- else:
- redis.set('running_model', running_model)
- redis.set('backend_online', 1)
- else:
- raise Exception
-
- # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
- # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
- average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
- if average_generation_elapsed_sec: # returns None on exception
- redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
-
- # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
- # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
-
- average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
- if average_generation_elapsed_sec:
- redis.set('average_output_tokens', average_output_tokens)
-
- # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
- # print(f'Weighted: {average_output_tokens}, overall: {overall}')
-
- estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
- redis.set('estimated_avg_tps', estimated_avg_tps)
- time.sleep(60)
diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py
new file mode 100644
index 0000000..d342f4b
--- /dev/null
+++ b/llm_server/workers/mainer.py
@@ -0,0 +1,57 @@
+import time
+
+import requests
+
+from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config, get_backends
+from llm_server.custom_redis import redis
+from llm_server.database.database import weighted_average_column_for_model
+from llm_server.llm.info import get_info
+from llm_server.routes.queue import RedisPriorityQueue, priority_queue
+
+
+def main_background_thread():
+ while True:
+ online, offline = get_backends()
+ for backend_url in online:
+ backend_info = cluster_config.get_backend(backend_url)
+ backend_mode = backend_info['mode']
+ backend_info = get_info(backend_url, backend_mode)
+ running_model = backend_info.get('model')
+ if not running_model:
+ continue
+
+ average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
+ if average_generation_elapsed_sec: # returns None on exception
+ cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
+ if average_output_tokens:
+ cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
+ if average_generation_elapsed_sec and average_output_tokens:
+ cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
+
+ if opts.background_homepage_cacher:
+ try:
+ base_client_api = redis.get('base_client_api', dtype=str)
+ r = requests.get('https://' + base_client_api, timeout=5)
+ except Exception as e:
+ print(f'Failed to fetch the homepage - {e.__class__.__name__}: {e}')
+
+ backends = priority_queue.get_backends()
+ for backend_url in backends:
+ queue = RedisPriorityQueue(backend_url)
+ queue.cleanup()
+
+ time.sleep(30)
+
+
+def calc_stats_for_backend(backend_url, running_model, backend_mode):
+ # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
+ # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
+ average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
+ running_model, backend_mode, backend_url, exclude_zeros=True,
+ include_system_tokens=opts.include_system_tokens_in_stats) or 0
+ average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
+ running_model, backend_mode, backend_url, exclude_zeros=True,
+ include_system_tokens=opts.include_system_tokens_in_stats) or 0
+ estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
+ return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps
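
As a quick sanity check on the arithmetic (hypothetical numbers):

    average_output_tokens = 180            # weighted average response length
    average_generation_elapsed_sec = 45    # weighted average generation time
    estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2)  # 4.0 tokens/sec
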
diff --git a/llm_server/workers/moderator.py b/llm_server/workers/moderator.py
index 4457d05..6d56eee 100644
--- a/llm_server/workers/moderator.py
+++ b/llm_server/workers/moderator.py
@@ -1,10 +1,13 @@
import json
import threading
+import time
import traceback
import redis as redis_redis
+from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint
+from llm_server.logging import create_logger
redis_moderation = redis_redis.Redis()
@@ -16,36 +19,43 @@ def start_moderation_workers(num_workers):
t.daemon = True
t.start()
i += 1
- print(f'Started {i} moderation workers.')
-def moderation_worker():
- while True:
- result = redis_moderation.blpop('queue:msgs_to_check')
- try:
- msg, tag = json.loads(result[1])
- _, categories = check_moderation_endpoint(msg)
- redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
- except:
- print(result)
- traceback.print_exc()
- continue
-
-
-def add_moderation_task(msg, tag):
- redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
-
+# TODO: don't use UUID tags to identify items. Use native redis
def get_results(tag, num_tasks):
- tag = str(tag) # Required for comparison with Redis results.
+ tag = str(tag) # Cast a UUID4 to a string.
flagged_categories = set()
num_results = 0
+ logger = create_logger('moderator')  # needed for the timeout warning below
+ start_time = time.time()
while num_results < num_tasks:
- result = redis_moderation.blpop('queue:flagged_categories')
+ result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
+ if result is None:
+ break # Timeout occurred, break the loop.
result_tag, categories = json.loads(result[1])
if result_tag == tag:
if categories:
for item in categories:
flagged_categories.add(item)
num_results += 1
+ if time.time() - start_time > opts.openai_moderation_timeout:
+ logger.warning('Timed out waiting for result from moderator')
+ break
return list(flagged_categories)
+
+
+def moderation_worker():
+ logger = create_logger('moderator')
+ while True:
+ result = redis_moderation.blpop(['queue:msgs_to_check'])
+ try:
+ msg, tag = json.loads(result[1])
+ _, categories = check_moderation_endpoint(msg)
+ redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
+ except:
+ logger.error(traceback.format_exc())
+ continue
+
+
+def add_moderation_task(msg, tag):
+ redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
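
Callers pair these two helpers: queue one task per message under a shared UUID tag, then collect the flagged categories for that tag. A minimal usage sketch (the OpenAI-compatible routes are the real callers; the messages here are placeholders):

    from uuid import uuid4

    messages = ['first user message', 'second user message']  # placeholders
    tag = uuid4()
    for msg in messages:
        add_moderation_task(msg, tag)
    flagged = get_results(tag, len(messages))
    print('Flagged categories:', flagged)
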
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index cb0f032..deb3246 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -1,25 +1,34 @@
-import logging
import time
+import traceback
-from llm_server.routes.cache import redis
+from llm_server.cluster.backend import get_running_models
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.custom_redis import redis
+from llm_server.logging import create_logger
from llm_server.routes.queue import priority_queue
-logger = logging.getLogger('console_printer')
-if not logger.handlers:
- handler = logging.StreamHandler()
- handler.setLevel(logging.INFO)
- logger.setLevel(logging.INFO)
- formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
- handler.setFormatter(formatter)
- logger.addHandler(handler)
-
def console_printer():
+ logger = create_logger('console_printer')
time.sleep(3)
while True:
- processing = redis.hkeys('processing_ips')
- processing_count = 0
- for ip in processing:
- processing_count += int(redis.hget('processing_ips', ip))
- logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
+ try:
+ processing = redis.keys('active_gen_workers:http*') # backends always start with http
+ processing_count = 0
+ if len(processing):
+ for k in processing:
+ processing_count += redis.get(k, default=0, dtype=int)
+ backends = [k for k, v in cluster_config.all().items() if v['online']]
+ activity = priority_queue.activity()
+
+ # Calculate the queue size the same way it's done on the stats.
+ queue_size = 0
+ running_models = get_running_models()
+ for model in running_models:
+ queue_size += priority_queue.len(model)
+
+ # Active Workers and Processing should read the same. If not, that's an issue.
+ logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
+ except:
+ logger.error(traceback.format_exc())
time.sleep(10)
diff --git a/llm_server/workers/recent.py b/llm_server/workers/recenter.py
similarity index 81%
rename from llm_server/workers/recent.py
rename to llm_server/workers/recenter.py
index d378a87..c6158d6 100644
--- a/llm_server/workers/recent.py
+++ b/llm_server/workers/recenter.py
@@ -1,6 +1,6 @@
import time
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
def recent_prompters_thread():
diff --git a/llm_server/workers/threader.py b/llm_server/workers/threader.py
new file mode 100644
index 0000000..542a630
--- /dev/null
+++ b/llm_server/workers/threader.py
@@ -0,0 +1,58 @@
+import time
+from threading import Thread
+
+from llm_server import opts
+from llm_server.cluster.worker import cluster_worker
+from llm_server.logging import create_logger
+from llm_server.routes.v1.generate_stats import generate_stats
+from llm_server.workers.inferencer import start_workers
+from llm_server.workers.logger import db_logger
+from llm_server.workers.mainer import main_background_thread
+from llm_server.workers.moderator import start_moderation_workers
+from llm_server.workers.printer import console_printer
+from llm_server.workers.recenter import recent_prompters_thread
+
+
+def cache_stats():
+ while True:
+ generate_stats(regen=True)
+ time.sleep(5)
+
+
+def start_background():
+ logger = create_logger('threader')
+ start_workers(opts.cluster)
+
+ t = Thread(target=main_background_thread)
+ t.daemon = True
+ t.start()
+ logger.info('Started the main background thread.')
+
+ num_moderators = opts.cluster_workers * 3
+ start_moderation_workers(num_moderators)
+ logger.info(f'Started {num_moderators} moderation workers.')
+
+ t = Thread(target=cache_stats)
+ t.daemon = True
+ t.start()
+ logger.info('Started the stats cacher.')
+
+ t = Thread(target=recent_prompters_thread)
+ t.daemon = True
+ t.start()
+ logger.info('Started the recent proompters thread.')
+
+ t = Thread(target=console_printer)
+ t.daemon = True
+ t.start()
+ logger.info('Started the console printer.')
+
+ t = Thread(target=cluster_worker)
+ t.daemon = True
+ t.start()
+ logger.info('Started the cluster worker.')
+
+ t = Thread(target=db_logger)
+ t.daemon = True
+ t.start()
+ logger.info('Started background logger.')
diff --git a/llm_server/workers/threads.py b/llm_server/workers/threads.py
deleted file mode 100644
index d1c5183..0000000
--- a/llm_server/workers/threads.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import time
-
-from llm_server.routes.v1.generate_stats import generate_stats
-
-
-def cache_stats():
- while True:
- generate_stats(regen=True)
- time.sleep(5)
diff --git a/other/gradio/gradio_chat.py b/other/gradio/gradio_chat.py
new file mode 100644
index 0000000..179e748
--- /dev/null
+++ b/other/gradio/gradio_chat.py
@@ -0,0 +1,103 @@
+import os
+import sys
+import time
+import traceback
+import warnings
+from threading import Thread
+
+import gradio as gr
+import openai
+import requests
+
+warnings.filterwarnings("ignore")
+
+API_BASE = os.getenv('API_BASE')
+if not API_BASE:
+ print('Must set the secret variable API_BASE to your https://your-site/api')
+ sys.exit(1)
+API_BASE = API_BASE.strip('/')
+
+APP_TITLE = os.getenv('APP_TITLE')
+PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
+TRACKING_CODE = os.getenv('TRACKING_CODE')
+
+
+def background():
+ while True:
+ previous = openai.api_base
+ try:
+ r = requests.get(API_BASE + '/stats').json()
+ if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
+ openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
+ else:
+ openai.api_base = API_BASE + '/openai/v1'
+ except:
+ traceback.print_exc()
+ openai.api_base = API_BASE + '/openai/v1'
+ if openai.api_base != previous:
+ print('Set primary model to', openai.api_base)
+ time.sleep(10)
+
+
+if PRIMARY_MODEL_CHOICE:
+ t = Thread(target=background)
+ t.daemon = True
+ t.start()
+ print('Started the background thread.')
+
+# A system prompt can be injected into the very first spot in the context.
+# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
+# the content in CONTEXT_TRIGGER_INJECTION will be injected.
+# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples.
+CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
+CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
+
+openai.api_key = 'null'
+openai.api_base = API_BASE + '/openai/v1'
+
+
+def stream_response(prompt, history):
+ messages = []
+ do_injection = False
+ for human, assistant in history:
+ messages.append({'role': 'user', 'content': str(human)})
+ messages.append({'role': 'assistant', 'content': str(assistant)})
+
+ if CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in human:
+ do_injection = True
+ messages.append({'role': 'user', 'content': prompt})
+
+ if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
+ messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
+
+ try:
+ response = openai.ChatCompletion.create(
+ model='0',
+ messages=messages,
+ temperature=0,
+ max_tokens=300,
+ stream=True,
+ headers={'LLM-Source': 'huggingface-demo'}
+ )
+ except Exception:
+ raise gr.Error("Failed to reach inference endpoint.")
+
+ message = ''
+ for chunk in response:
+ if len(chunk['choices'][0]['delta']) != 0:
+ message += chunk['choices'][0]['delta']['content']
+ yield message
+
+
+examples = ["hello"]
+if CONTEXT_TRIGGER_PHRASE:
+ examples.insert(0, CONTEXT_TRIGGER_PHRASE)
+
+with gr.Blocks(analytics_enabled=False) as demo:
+ gr.ChatInterface(stream_response, examples=examples, title=APP_TITLE, analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}')
+
+ if TRACKING_CODE:
+ print('Inserting tracking code')
+ gr.HTML(TRACKING_CODE)
+
+demo.queue(concurrency_count=1, api_open=False).launch(show_api=False)
diff --git a/other/gradio/requirements.txt b/other/gradio/requirements.txt
new file mode 100644
index 0000000..eb4baac
--- /dev/null
+++ b/other/gradio/requirements.txt
@@ -0,0 +1,3 @@
+gradio
+openai
+requests
\ No newline at end of file
diff --git a/gunicorn.py b/other/gunicorn.py
similarity index 60%
rename from gunicorn.py
rename to other/gunicorn.py
index 30f9274..099e9ce 100644
--- a/gunicorn.py
+++ b/other/gunicorn.py
@@ -1,3 +1,8 @@
+"""
+This file is used to run certain tasks when the HTTP server starts.
+It's located here so it doesn't get imported with daemon.py
+"""
+
try:
import gevent.monkey
diff --git a/other/nginx-site.conf b/other/nginx-site.conf
new file mode 100644
index 0000000..1c81d3d
--- /dev/null
+++ b/other/nginx-site.conf
@@ -0,0 +1,38 @@
+server
+{
+ listen 443 ssl http2 default_server;
+ server_name _;
+
+ proxy_set_header Host $host;
+ proxy_set_header Connection $http_connection;
+ proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
+ proxy_set_header X-Scheme $scheme;
+
+ location ~* ^/api/(.*?|v1|openai)/(v1|(generate|stream)|(chat/completions|completions))$
+ {
+ # Route to inference endpoints
+ proxy_pass http://127.0.0.1:5000;
+
+ # Required for streaming (both websockets and SSE).
+ proxy_buffering off;
+ proxy_cache off;
+ proxy_http_version 1.1;
+ proxy_set_header Upgrade $http_upgrade;
+ proxy_set_header Connection "upgrade";
+
+ # Set long timeouts for inference operations.
+ # Cloudflare has a timeout of 100 seconds.
+ proxy_read_timeout 120;
+ proxy_connect_timeout 120;
+ proxy_send_timeout 120;
+ }
+
+ location /
+ {
+ proxy_pass http://127.0.0.1:5000;
+ }
+
+ ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt;
+ ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key;
+ include /etc/nginx/snippets/ssl-params.conf;
+}
diff --git a/other/tests/config.sh b/other/tests/config.sh
new file mode 100644
index 0000000..64bea46
--- /dev/null
+++ b/other/tests/config.sh
@@ -0,0 +1,11 @@
+HOST="proxy.chub-archive.evulid.cc"
+
+AUTH_KEY="TEST_1df979f0-6df1-41bd-814a-e99b1680e727"
+
+PROXY_SERVERS=(
+ "http://172.0.4.7:3128"
+ "http://172.0.4.8:3128"
+ "http://172.0.4.10:3128"
+ "http://172.0.4.12:3128"
+ "http://172.0.4.13:3128"
+)
diff --git a/other/tests/generate.sh b/other/tests/generate.sh
new file mode 100755
index 0000000..b1443c0
--- /dev/null
+++ b/other/tests/generate.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+SLEEP_TIME=2
+
+while getopts p:t: flag; do
+ case "${flag}" in
+ p) PROXY_CHOICE=${OPTARG} ;;
+ t) SLEEP_TIME=${OPTARG} ;;
+ *) ;;
+ esac
+done
+
+SOURCE=${BASH_SOURCE[0]}
+while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
+ DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+ SOURCE=$(readlink "$SOURCE")
+ [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
+done
+DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+
+source "$DIR/config.sh"
+
+if [ -n "$PROXY_CHOICE" ]; then
+ our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
+ echo "Using $our_proxy_server"
+else
+ our_proxy_server=""
+fi
+
+while true; do
+ echo "--> START <--"
+
+ DATA=$(
+ cat < DONE <--\n"
+ sleep $SLEEP_TIME
+done
diff --git a/other/tests/oai-chat-completion.sh b/other/tests/oai-chat-completion.sh
new file mode 100755
index 0000000..5355a8a
--- /dev/null
+++ b/other/tests/oai-chat-completion.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+DO_STREAM=false
+SLEEP_TIME=2
+
+while getopts p:t:s flag; do
+ case "${flag}" in
+ s) DO_STREAM=true ;;
+ p) PROXY_CHOICE=${OPTARG} ;;
+ t) SLEEP_TIME=${OPTARG} ;;
+ *) ;;
+ esac
+done
+
+SOURCE=${BASH_SOURCE[0]}
+while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
+ DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+ SOURCE=$(readlink "$SOURCE")
+ [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
+done
+DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+
+source "$DIR/config.sh"
+
+if [ ! -z "$PROXY_CHOICE" ]; then
+ our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
+ echo "Using $our_proxy_server"
+else
+ our_proxy_server=""
+fi
+
+while true; do
+ echo "--> START <--"
+
+ DATA=$(
+ cat < DONE <--\n"
+ sleep $SLEEP_TIME
+done
diff --git a/other/tests/oai-completion.sh b/other/tests/oai-completion.sh
new file mode 100755
index 0000000..cc0f9f0
--- /dev/null
+++ b/other/tests/oai-completion.sh
@@ -0,0 +1,52 @@
+#!/bin/bash
+
+DO_STREAM=false
+SLEEP_TIME=2
+
+while getopts p:t:s flag; do
+ case "${flag}" in
+ s) DO_STREAM=true ;;
+ p) PROXY_CHOICE=${OPTARG} ;;
+ t) SLEEP_TIME=${OPTARG} ;;
+ *) ;;
+ esac
+done
+
+SOURCE=${BASH_SOURCE[0]}
+while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
+ DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+ SOURCE=$(readlink "$SOURCE")
+ [[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
+done
+DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
+
+source "$DIR/config.sh"
+
+if [ ! -z "$PROXY_CHOICE" ]; then
+ our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
+ echo "Using $our_proxy_server"
+else
+ our_proxy_server=""
+fi
+
+while true; do
+ echo "--> START <--"
+
+ DATA=$(
+ cat < DONE <--\n"
+ sleep $SLEEP_TIME
+done
diff --git a/other/tests/start-bulk.sh b/other/tests/start-bulk.sh
new file mode 100755
index 0000000..6f254d5
--- /dev/null
+++ b/other/tests/start-bulk.sh
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# Function to display help message
+function display_help {
+ echo "Usage: $0 -n num_windows -c command"
+ echo
+ echo " -n, --number Number of windows to create"
+ echo " -c, --command Command to run in each window"
+ echo
+ exit 1
+}
+
+# Parse command line arguments
+while getopts "n:c:h" opt; do
+ case ${opt} in
+ n)
+ num_windows=${OPTARG}
+ ;;
+ c)
+ command=${OPTARG}
+ ;;
+ h)
+ display_help
+ ;;
+ \?)
+ echo "Invalid option: -$OPTARG" 1>&2
+ display_help
+ ;;
+ :)
+ echo "Option -$OPTARG requires an argument." 1>&2
+ display_help
+ ;;
+ esac
+done
+
+# Check if number of windows and command are provided
+if [ -z "$num_windows" ] || [ -z "$command" ]; then
+ echo "Both number of windows and command are required."
+ display_help
+fi
+
+# Calculate rows and columns
+rows=$(echo "sqrt($num_windows)" | bc)
+columns=$(echo "($num_windows + $rows - 1) / $rows" | bc)
+
+# Create a new tmux session
+tmux new-session -d -s llm_tester "$command -p 0"
+
+# Create the remaining windows
+for ((i = 1; i < $num_windows; i++)); do
+ if ((i % $columns == 0)); then
+ tmux select-layout -t llm_tester:0 tiled
+ tmux select-pane -t 0
+ tmux split-window -t llm_tester:0 -v "$command -p $i"
+ else
+ tmux split-window -t llm_tester:0 -h "$command -p $i"
+ fi
+done
+
+# Balance the windows
+tmux select-layout -t llm_tester:0 tiled
+
+# Attach to the session
+tmux attach-session -t llm_tester
diff --git a/other/ooba-test-streaming.py b/other/tests/stream.py
old mode 100644
new mode 100755
similarity index 52%
rename from other/ooba-test-streaming.py
rename to other/tests/stream.py
index 883c2f5..75d403b
--- a/other/ooba-test-streaming.py
+++ b/other/tests/stream.py
@@ -1,37 +1,50 @@
import asyncio
import json
import sys
+import os
+import time
+from pathlib import Path
try:
import websockets
except ImportError:
print("Websockets package not found. Make sure it's installed.")
-# For local streaming, the websockets are hosted without ssl - ws://
-HOST = 'localhost:5000'
-URI = f'ws://{HOST}/api/v1/stream'
+script_path = os.path.dirname(os.path.realpath(__file__))
-# For reverse-proxied streaming, the remote will likely host with ssl - wss://
-# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
+
+def parse_bash_config(file_path):
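+    # Minimal parser for config.sh: reads KEY="value" strings and KEY=(...) arrays,
+    # skipping comment lines and any line without an '='.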
+ config = {}
+ with open(file_path, 'r') as f:
+ for line in f:
+ if line.startswith('#') or '=' not in line:
+ continue
+ key, value = line.strip().split('=', 1)
+ if value.startswith('"') and value.endswith('"'):
+ value = value[1:-1]
+ elif value.startswith('(') and value.endswith(')'):
+ value = value[1:-1].split()
+ value = [v.strip('"') for v in value]
+ config[key] = value
+ return config
+
+
+config = parse_bash_config(Path(script_path, 'config.sh'))
async def run(context):
- # Note: the selected defaults change from time to time.
request = {
'prompt': context,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
-
- # Generation params. If 'preset' is set to different than 'None', the values
- # in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
- 'epsilon_cutoff': 0, # In units of 1e-4
- 'eta_cutoff': 0, # In units of 1e-4
+ 'epsilon_cutoff': 0,
+ 'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
@@ -48,7 +61,6 @@ async def run(context):
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
-
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
@@ -58,7 +70,7 @@ async def run(context):
'stopping_strings': []
}
- async with websockets.connect(URI, ping_interval=None) as websocket:
+ async with websockets.connect(f'wss://{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket:
await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply
@@ -67,20 +79,28 @@ async def run(context):
incoming_data = await websocket.recv()
incoming_data = json.loads(incoming_data)
+ print(incoming_data)
+
match incoming_data['event']:
- case 'text_stream':
- yield incoming_data['text']
+ # case 'text_stream':
+ # yield incoming_data['text']
case 'stream_end':
return
async def print_response_stream(prompt):
- async for response in run(prompt):
- print(response, end='')
- sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
- print('\n\nfinished')
+ try:
+ async for response in run(prompt):
+ print(response, end='')
+ sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
+ except Exception as e:
+ print(e)
if __name__ == '__main__':
- prompt = "In order to make homemade bread, follow these steps:\n1)"
- asyncio.run(print_response_stream(prompt))
+ prompt = "Write a 300 word story about an apple tree.\n\n"
+ while True:
+ print('--> START <--')
+ asyncio.run(print_response_stream(prompt))
+ print('--> DONE <--')
+ time.sleep(2)
diff --git a/other/vllm/Docker/DOCKER.md b/other/vllm/Docker/DOCKER.md
deleted file mode 100644
index 6abf6bf..0000000
--- a/other/vllm/Docker/DOCKER.md
+++ /dev/null
@@ -1,15 +0,0 @@
-**A Docker container for running VLLM on Paperspace Gradient notebooks.**
-
-1. Run `jupyter server --generate-config` and `jupyter server password` on your local machine, then copy Jupyter's config directory to `./jupyter`
-2. Place your Rathole client config at `./rathole-client.toml`
-3. `docker build . -t "paperspace-vllm"`
-
-To test on your local machine, run this command:
-
-```bash
-docker run --shm-size 14g --gpus all \
- -v /storage/models/awq/MythoMax-L2-13B-AWQ:/models/MythoMax-L2-13B-AWQ \
- -p 7000:7000 -p 8888:8888 \
- -e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 99999 --gpu-memory-utilization 1" \
- vllm-cloud
-```
\ No newline at end of file
diff --git a/other/vllm/Docker/Dockerfile b/other/vllm/Docker/Dockerfile
index d3c02e8..7ebe7b0 100644
--- a/other/vllm/Docker/Dockerfile
+++ b/other/vllm/Docker/Dockerfile
@@ -1,87 +1,50 @@
-FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build
-
-RUN apt-get update && \
- apt-get install -y git python3-pip python3-venv wget unzip && \
- rm -rf /var/lib/apt/lists/*
-RUN pip3 install --upgrade pip setuptools wheel
-
-RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
-
-WORKDIR /local-llm-server
-
-RUN python3 -m venv /venv
-RUN /venv/bin/pip install git+https://github.com/vllm-project/vllm
-
-RUN python3 -m venv /jupyterlab
-RUN /jupyterlab/bin/pip install jupyterlab
-RUN /jupyterlab/bin/jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
-
-RUN mkdir -p /app
-RUN wget https://github.com/rapiz1/rathole/releases/download/v0.4.8/rathole-x86_64-unknown-linux-gnu.zip -O /tmp/rathole.zip
-RUN unzip -j /tmp/rathole.zip -d /tmp
-RUN rm /tmp/rathole.zip
-RUN cp /tmp/rathole /app
-
-# The local local-llm-server repo may be cached, so we will fetch and reset to the remote every time.
-# Also, make sure there weren't any pip deps added.
-ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache
-RUN git fetch; git reset --hard origin/master
-RUN /venv/bin/pip install -r requirements.txt
-
-FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as runtime
-
-RUN apt-get update && apt-get install -y supervisor && rm -rf /var/lib/apt/lists/*
+FROM cyberes/vllm-paperspace-base as runtime
RUN useradd -ms /bin/bash apiserver
RUN usermod -s /bin/bash root
+# Required packages
RUN apt-get update && \
- apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client nano tmux file && \
+ apt-get install -y python3 python3-pip supervisor && \
+ rm -rf /var/lib/apt/lists/*
+RUN pip3 install --upgrade pip setuptools wheel
+
+# Useful Python packages
+RUN pip3 install glances
+
+# Useful tools
+RUN apt-get update && \
+ apt-get install -y wget aria2 git-lfs git openssh-server openssh-client nano tmux file && \
rm -rf /var/lib/apt/lists/*
-RUN pip3 install --upgrade pip setuptools wheel
-RUN pip3 install glances
+# Update the git repo
+RUN cd /local-llm-server && git reset --hard && git pull
# Enable root SSH login
RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
-
# Disable password SSH login
RUN sed -i 's/#PasswordAuthentication yes/PasswordAuthentication no/' /etc/ssh/sshd_config
-
-# Create the necessary directory for SSH
+# Create the necessary directory for sshd
RUN mkdir /var/run/sshd
-ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache
-
-COPY --from=build /local-llm-server /local-llm-server
-COPY --from=build /venv /venv
-COPY --from=build /app /app
-COPY --from=build /jupyterlab /jupyterlab
-
-RUN cp /local-llm-server/other/vllm/Docker/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
-RUN cp /local-llm-server/other/vllm/Docker/start-vllm.sh /app/start-vllm.sh
-RUN cp /local-llm-server/other/vllm/Docker/start-container.sh /app/start.sh
-
-# Copy your secrets in
-# COPY ./jupyter /app/jupyter
+COPY supervisord.conf /etc/supervisor/supervisord.conf
+COPY start-vllm.sh /app/start-vllm.sh
+COPY init-container.sh /app/init.sh
+COPY start-container.sh /app/start.sh
RUN mkdir -p /var/log/app/
RUN chown -R apiserver:apiserver /local-llm-server && \
chown -R apiserver:apiserver /app && \
chown -R apiserver:apiserver /var/log/app/
+RUN git config --global --add safe.directory /local-llm-server
+RUN chmod +x /app/init.sh
RUN chmod +x /app/start.sh
ENV SHELL="/bin/bash"
-# SSH
-EXPOSE 22
-
-# VLLM
-EXPOSE 7000
-
-# Jupyter
+# Expose Jupyter. We don't need to expose VLLM or SSH since rathole will tunnel those.
EXPOSE 8888
CMD /app/start.sh
diff --git a/other/vllm/Docker/Dockerfile.base b/other/vllm/Docker/Dockerfile.base
new file mode 100644
index 0000000..bcd4d6f
--- /dev/null
+++ b/other/vllm/Docker/Dockerfile.base
@@ -0,0 +1,43 @@
+# This container builds and assembles the Python parts of the Docker container.
+# It is used as the base for the resulting container, which avoids having to re-push
+# the large PyTorch parts every time the application is rebuilt.
+
+FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build
+
+RUN apt-get update && \
+ apt-get install -y git python3-pip python3-venv wget unzip && \
+ rm -rf /var/lib/apt/lists/*
+RUN pip install --upgrade pip setuptools wheel
+
+RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
+
+RUN python3 -m venv /jupyterlab
+RUN /jupyterlab/bin/pip install jupyterlab
+RUN /jupyterlab/bin/jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
+
+RUN mkdir -p /app
+RUN wget https://github.com/rapiz1/rathole/releases/download/v0.4.8/rathole-x86_64-unknown-linux-gnu.zip -O /tmp/rathole.zip
+RUN unzip -j /tmp/rathole.zip -d /tmp
+RUN rm /tmp/rathole.zip
+RUN cp /tmp/rathole /app
+
+RUN python3 -m venv /venv
+RUN /venv/bin/pip3 install --upgrade pip setuptools wheel
+
+# Install PyTorch before installing VLLM to ensure we use the right version for our CUDA install.
+RUN wget -q -O - https://raw.githubusercontent.com/vllm-project/vllm/main/requirements.txt | grep -E 'torch*' > /tmp/torch_version
+RUN /venv/bin/pip3 install "$(cat /tmp/torch_version)" --index-url https://download.pytorch.org/whl/cu118
+
+# WORKDIR /local-llm-server
+
+# Don't build VLLM because we don't do that on the inference server. Just install from pip.
+# RUN /venv/bin/pip install git+https://github.com/vllm-project/vllm
+
+RUN /venv/bin/pip install vllm
+
+FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
+
+COPY --from=build /local-llm-server /local-llm-server
+COPY --from=build /venv /venv
+COPY --from=build /app /app
+COPY --from=build /jupyterlab /jupyterlab
diff --git a/other/vllm/Docker/README.md b/other/vllm/Docker/README.md
new file mode 100644
index 0000000..97faf32
--- /dev/null
+++ b/other/vllm/Docker/README.md
@@ -0,0 +1,47 @@
+**A Docker container for running VLLM on Paperspace Gradient notebooks.**
+
+### Running
+
+1. In Paperspace, create a new notebook.
+2. Click `Start from Scratch`.
+3. Select your GPU and set the auto-shutdown timeout to 6 hours.
+4. Click the `View Advanced Options` button at the bottom of the page. Enter these details in the form that appears:
+ - Container Name: `cyberes/vllm-paperspace:latest`
+ - Container Command: `/app/start.sh`
+5. Start the notebook. It may take up to five minutes for Paperspace to pull and start the custom image.
+6. Once the container is started, open the log viewer by clicking the icon in the bottom left of the screen. You should see errors from rathole and VLLM as a result of the blank config files. The container will create a new directory in your mounted
+ storage: `/storage/vllm/`.
+7. Enter your rathole client config in `/storage/vllm/rathole-client.toml` (a sketch is shown after these steps). If you need a visual text editor, first link the directory back to the Jupyter home: `ln -s /storage/vllm /notebooks`
+8. Restart rathole with `supervisorctl restart rathole` and then view the log: `tail -f /var/log/app/rathole.log`. If you see lines that start with `INFO` and end with `Control channel established`, rathole has connected and is working. Error messages will begin
+ with `ERROR`.
+9. Download an AWQ quantization from [TheBloke](https://huggingface.co/TheBloke) to `/storage/vllm/models/`.
+10. Enter your VLLM command-line args in `/storage/vllm/cmd.txt` (an example is shown after these steps). You need to set `--model` to the path of the model you want to load.
+11. Restart VLLM with `supervisorctl restart vllm` and then view the log: `tail -f /var/log/app/vllm.log`. It may take up to three minutes to load. When you see the line:
+ ```
+ INFO: Uvicorn running on http://0.0.0.0:7000 (Press CTRL+C to quit)
+ ```
+ VLLM is running and ready for queries.
+
+12. In `/notebooks` (the home directory of Jupyter), the notebook `idle.ipynb` will automatically be created. Run this notebook so Paperspace does not shut down your machine due to "inactivity". You **must** keep the running notebook open in a
+ browser tab.
+
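+A sketch of `/storage/vllm/rathole-client.toml` is below. The server address, token, and service name are placeholders; substitute the values for your own rathole server.
+
+```toml
+[client]
+remote_addr = "your-rathole-server.example.com:2333"
+
+[client.services.vllm]
+token = "your_secret_token"
+local_addr = "127.0.0.1:7000"
+```
+
+And an example `/storage/vllm/cmd.txt`, based on the default that `init-container.sh` writes (point `--model` at the model you downloaded):
+
+```
+--max-num-batched-tokens 4098 --quantization awq --model /storage/vllm/models/MythoMax-L2-13B-AWQ
+```
+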
+### Building
+
+You **must** have a GPU attached to your system when building the container (required for building VLLM).
+
+1. Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and CUDA 11.8.
+2. `bash build-docker.sh`
+
+To run the container on your local machine:
+
+```bash
+sudo docker run -it --shm-size 14g --gpus all -v /home/user/testing123/notebooks:/notebooks -v /home/user/testing123/storage:/storage -p 8888:8888 cyberes/vllm-paperspace:latest
+```
+
+You will need to create a directory to mount inside the container (for example: `/home/user/testing123/`). Within it should be a `models` folder that holds the model to load, plus `rathole-client.toml` and `cmd.txt` (see the example layout below).
+
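+One possible layout of that directory, assuming the paths used by `init-container.sh` (the model folder name is only an example):
+
+```
+/home/user/testing123/
+├── notebooks/
+└── storage/
+    └── vllm/
+        ├── models/
+        │   └── MythoMax-L2-13B-AWQ/
+        ├── rathole-client.toml
+        └── cmd.txt
+```
+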
+If you need to debug something, you can start a shell inside the container:
+
+```bash
+sudo docker run -it --shm-size 14g --gpus all -v /home/user/testing123/notebooks:/notebooks -v /home/user/testing123/storage:/storage -p 8888:8888 --entrypoint bash cyberes/vllm-paperspace:latest
+```
diff --git a/other/vllm/Docker/build-docker.sh b/other/vllm/Docker/build-docker.sh
new file mode 100644
index 0000000..f95ad4f
--- /dev/null
+++ b/other/vllm/Docker/build-docker.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+# Build and push the container.
+
+git pull || exit
+sudo docker build . -f Dockerfile.base -t cyberes/vllm-paperspace-base --no-cache && sudo docker push cyberes/vllm-paperspace-base:latest || exit
+sudo docker build . -t cyberes/vllm-paperspace && sudo docker push cyberes/vllm-paperspace:latest
diff --git a/other/vllm/Docker/idle.ipynb b/other/vllm/Docker/idle.ipynb
new file mode 100644
index 0000000..057e227
--- /dev/null
+++ b/other/vllm/Docker/idle.ipynb
@@ -0,0 +1,40 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "49ae6555-572b-4463-ba01-cc4331932a6c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "i = 0\n",
+ "while True:\n",
+ " print(i)\n",
+ " i += 1\n",
+ " time.sleep(1)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/other/vllm/Docker/init-container.sh b/other/vllm/Docker/init-container.sh
new file mode 100644
index 0000000..111646c
--- /dev/null
+++ b/other/vllm/Docker/init-container.sh
@@ -0,0 +1,22 @@
+#!/bin/bash
+
+# Create the required directories and files.
+echo "SETTING UP FILE SYSTEM..."
+mkdir -p /storage/vllm/
+chown -R apiserver:apiserver /storage/vllm
+touch /storage/vllm/cmd.txt
+touch /storage/vllm/rathole-client.toml
+
+# The user can store SSH auth and authorized_keys to streamline SSH login.
+if [ -f /storage/vllm/ssh ]; then
+ cp -r /storage/vllm/ssh /root/.ssh
+ echo "Copied ssh from /storage"
+fi
+
+# If the user has not filled in the VLLM command-line arg file, write the default.
+# Check for an empty file, since the touch above always creates it.
+if [ ! -s /storage/vllm/cmd.txt ]; then
+ echo "--max-num-batched-tokens 4098 --quantization awq --model /storage/vllm/models/model-path" >/storage/vllm/cmd.txt
+fi
+
+# Copy the idling notebook to storage. This will create a blank notebook every time the container is started.
+cp /local-llm-server/other/vllm/Docker/idle.ipynb /notebooks/idle.ipynb
diff --git a/other/vllm/Docker/start-container.sh b/other/vllm/Docker/start-container.sh
index 0b98702..05587a1 100644
--- a/other/vllm/Docker/start-container.sh
+++ b/other/vllm/Docker/start-container.sh
@@ -1,13 +1,4 @@
#!/bin/bash
-mkdir -p /storage/vllm/
-chown -R apiserver:apiserver /storage/vllm
-touch /storage/vllm/cmd.txt
-touch /storage/vllm/rathole-client.toml
-
-if [ -f /storage/vllm/ssh ]; then
- cp -r /storage/vllm/ssh /root/.ssh
- echo "Copied ssh from /storage"
-fi
-
-/usr/bin/supervisord
+# Start the services and launch the container.
+/usr/bin/supervisord -c /etc/supervisor/supervisord.conf
diff --git a/other/vllm/Docker/start-vllm.sh b/other/vllm/Docker/start-vllm.sh
index 906bc30..209e90a 100644
--- a/other/vllm/Docker/start-vllm.sh
+++ b/other/vllm/Docker/start-vllm.sh
@@ -6,9 +6,4 @@ for pid in $vllm_pid; do
kill -9 $pid
done
-cd /local-llm-server
-git fetch
-git reset --hard origin/master
-/venv/bin/pip install -r requirements.txt
-
/venv/bin/python /local-llm-server/other/vllm/vllm_api_server.py --host 0.0.0.0 --port 7000 --max-log-len 100 $(cat /storage/vllm/cmd.txt)
diff --git a/other/vllm/Docker/supervisord.conf b/other/vllm/Docker/supervisord.conf
index 9361bdb..800cb27 100644
--- a/other/vllm/Docker/supervisord.conf
+++ b/other/vllm/Docker/supervisord.conf
@@ -1,5 +1,25 @@
[supervisord]
-nodaemon=true
+nodaemon = true
+user=root
+pidfile = /var/run/supervisord.pid
+logfile = /var/log/app/supervisord.log
+directory = /tmp
+
+[unix_http_server]
+file=/var/run/supervisor.sock
+chmod=0770
+
+[rpcinterface:supervisor]
+supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
+
+[supervisorctl]
+serverurl=unix:///var/run/supervisor.sock
+
+[program:startup]
+command=/app/init.sh
+autostart=true
+autorestart=false
+startsecs=0
[program:vllm]
command=/bin/bash -c 'bash /app/start-vllm.sh 2>&1 | tee -a /var/log/app/vllm.log'
@@ -24,9 +44,20 @@ user=apiserver
environment=HOME="/home/apiserver",USER="apiserver"
[program:jupyter]
-command=/jupyterlab/bin/jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True
+command=/jupyterlab/bin/jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True --notebook-dir /notebooks
environment=SHELL="/bin/bash"
-; JUPYTER_CONFIG_DIR="/app/jupyter"
+autostart=true
+autorestart=true
+stdout_logfile=/dev/fd/1
+stdout_logfile_maxbytes=0
+stderr_logfile=/dev/fd/2
+stderr_logfile_maxbytes=0
[program:ssh]
command=/usr/sbin/sshd -D
+autostart=true
+autorestart=true
+stdout_logfile=/dev/fd/1
+stdout_logfile_maxbytes=0
+stderr_logfile=/dev/fd/2
+stderr_logfile_maxbytes=0
diff --git a/other/vllm/Docker/update-container.sh b/other/vllm/Docker/update-container.sh
new file mode 100755
index 0000000..d44d6d9
--- /dev/null
+++ b/other/vllm/Docker/update-container.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+# Run this script to update the container.
+# Will restart VLLM as well.
+
+cd /local-llm-server || exit
+
+git fetch
+git reset --hard origin/master
+
+supervisorctl restart vllm
diff --git a/requirements.txt b/requirements.txt
index 9b0c8eb..89f4be7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,21 +1,18 @@
flask~=2.3.3
-flask_cors
pyyaml~=6.0.1
-flask_caching
+Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
-gunicorn
gevent~=23.9.0.post1
-async-timeout
-flask-sock
-uvicorn~=0.23.2
-fastapi~=0.103.1
-torch~=2.0.1
PyMySQL~=1.1.0
-DBUtils~=3.0.3
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
-urllib3~=2.0.4
-celery[redis]
+flask-sock==0.6.0
+gunicorn==21.2.0
+redis==5.0.1
+ujson==5.8.0
+vllm==0.2.1.post1
+gradio~=3.46.1
+coloredlogs~=15.0.1
\ No newline at end of file
diff --git a/server.py b/server.py
index 06482d4..aa8ef1a 100644
--- a/server.py
+++ b/server.py
@@ -1,5 +1,3 @@
-from llm_server.config.config import mode_ui_names
-
try:
import gevent.monkey
@@ -7,37 +5,46 @@ try:
except ImportError:
pass
-from llm_server.pre_fork import server_startup
-from llm_server.config.load import load_config
import os
import sys
from pathlib import Path
import simplejson as json
-from flask import Flask, jsonify, render_template, request
+from flask import Flask, jsonify, render_template, request, Response
-import llm_server
+import config
+from llm_server import opts
+from llm_server.cluster.backend import get_model_choices
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.config.config import mode_ui_names
+from llm_server.config.load import load_config
+from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database
from llm_server.database.create import create_db
-from llm_server.llm import get_token_count
-from llm_server.routes.openai import openai_bp
+from llm_server.helpers import auto_set_base_client_api
+from llm_server.llm.vllm.info import vllm_info
+from llm_server.pre_fork import server_startup
+from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
-from llm_server.stream import init_socketio
+from llm_server.routes.v1.generate_stats import generate_stats
+from llm_server.sock import init_wssocket
-# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
-# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
-# TODO: implement background thread to test backends via sending test prompts
-# TODO: if backend fails request, mark it as down
-# TODO: allow setting concurrent gens per-backend
-# TODO: set the max tokens to that of the lowest backend
-# TODO: implement RRD backend loadbalancer option
-# TODO: have VLLM reject a request if it already has n == concurrent_gens running
-# TODO: add a way to cancel VLLM gens. Maybe use websockets?
-# TODO: use coloredlogs
-# TODO: need to update opts. for workers
+# TODO: separate queue item timeout for websockets (make longer, like 5 minutes)
+# TODO: return an `error: True`, error code, and error message rather than just a formatted message
+# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
+# TODO: redis SCAN vs KEYS??
+# TODO: is frequency penalty the same as ooba repetition penalty???
+# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# Lower priority
+# TODO: if a backend is at its limit of concurrent requests, choose a different one
+# TODO: make error messages consistent
+# TODO: support logit_bias on OpenAI and Ooba endpoints.
+# TODO: add a way to cancel VLLM gens. Maybe use websockets?
+# TODO: validate openai_silent_trim works as expected and only when enabled
+# TODO: rewrite config storage. Store in redis so we can reload it.
+# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estimated wait time lags behind the stats
# TODO: simulate OpenAI error messages regardless of endpoint
@@ -59,19 +66,16 @@ except ModuleNotFoundError as e:
print('Please see README.md for install instructions.')
sys.exit(1)
-import config
-from llm_server import opts
-from llm_server.helpers import auto_set_base_client_api
-from llm_server.llm.vllm.info import vllm_info
-from llm_server.routes.cache import RedisWrapper, flask_cache
-from llm_server.llm import redis
-from llm_server.routes.stats import get_active_gen_workers
-from llm_server.routes.v1.generate_stats import generate_stats
-
app = Flask(__name__)
-init_socketio(app)
-app.register_blueprint(bp, url_prefix='/api/v1/')
+
+# Fixes ConcurrentObjectUseError
+# https://github.com/miguelgrinberg/simple-websocket/issues/24
+app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
+
+app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
+app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
+init_wssocket(app)
flask_cache.init_app(app)
flask_cache.clear()
@@ -82,18 +86,13 @@ if config_path_environ:
else:
config_path = Path(script_path, 'config', 'config.yml')
-success, config, msg = load_config(config_path, script_path)
+success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
-llm_server.llm.redis = RedisWrapper('local_llm')
-create_db()
-
-
-# print(app.url_map)
@app.route('/')
@@ -101,20 +100,30 @@ create_db()
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
+ base_client_api = redis.get('base_client_api', dtype=str)
stats = generate_stats()
+ model_choices, default_model = get_model_choices()
- if not stats['online']:
- running_model = estimated_wait_sec = 'offline'
- else:
- running_model = redis.get('running_model', str, 'ERROR')
+ if default_model:
+ if not model_choices.get(default_model):
+ return 'The server is still starting up. Please wait...'
- active_gen_workers = get_active_gen_workers()
- if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
+ default_model_info = model_choices[default_model]
+
+ if default_model_info['queued'] == 0 and default_model_info['processing'] >= default_model_info['concurrent_gens']:
# There will be a wait if the queue is empty but prompts are processing, but we don't
# know how long.
- estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
+ default_estimated_wait_sec = f"less than {int(default_model_info['estimated_wait'])} seconds"
else:
- estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"
+ default_estimated_wait_sec = f"{int(default_model_info['estimated_wait'])} seconds"
+ else:
+ default_model_info = {
+ 'model': 'OFFLINE',
+ 'processing': '-',
+ 'queued': '-',
+ 'context_size': '-',
+ }
+ default_estimated_wait_sec = 'OFFLINE'
if len(config['analytics_tracking_code']):
analytics_tracking_code = f""
@@ -127,32 +136,47 @@ def home():
info_html = ''
mode_info = ''
- if opts.mode == 'vllm':
- mode_info = vllm_info
-
- base_client_api = redis.get('base_client_api', str)
+ for k, v in cluster_config.all().items():
+ if v['mode'] == 'vllm':
+ mode_info = vllm_info
+ break
return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name,
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
- current_model=opts.manual_model_name if opts.manual_model_name else running_model,
+ default_model=default_model_info['model'],
+ default_active_gen_workers=default_model_info['processing'],
+ default_proompters_in_queue=default_model_info['queued'],
+ current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}',
- ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
- estimated_wait=estimated_wait_sec,
- mode_name=mode_ui_names[opts.mode][0],
- api_input_textbox=mode_ui_names[opts.mode][1],
- streaming_input_textbox=mode_ui_names[opts.mode][2],
- context_size=opts.context_size,
+ ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
+ default_estimated_wait=default_estimated_wait_sec,
+ mode_name=mode_ui_names[opts.frontend_api_mode][0],
+ api_input_textbox=mode_ui_names[opts.frontend_api_mode][1],
+ streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2],
+ default_context_size=default_model_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info,
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
expose_openai_system_prompt=opts.expose_openai_system_prompt,
enable_streaming=opts.enable_streaming,
+ model_choices=model_choices,
+ proompters_5_min=stats['stats']['proompters']['5_min'],
+ proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
)
-# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
+@app.route('/robots.txt')
+def robots():
+ # TODO: have config value to deny all
+ # TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt
+ t = """User-agent: *
+Allow: /"""
+ r = Response(t)
+ r.headers['Content-Type'] = 'text/plain'
+ return r
+
@app.route('/')
@app.route('//')
diff --git a/templates/home.html b/templates/home.html
index 4b9c153..66340a6 100644
--- a/templates/home.html
+++ b/templates/home.html
@@ -65,6 +65,19 @@
.hidden {
display: none;
}
+
+ .header-workers {
+ font-weight: normal;
+ font-size: 14pt;
+ }
+
+ h3 {
+ font-size: 16pt;
+ }
+
+ .no-marker {
+ list-style: none;
+ }
@@ -76,8 +89,12 @@
{{ llm_middleware_name }}
-
Current Model: {{ current_model }}
-
Estimated Wait Time: {{ estimated_wait }}
+
Current Model: {{ default_model }}
+
+ Estimated Wait Time: {{ default_estimated_wait }}
+ Processing: {{ default_active_gen_workers }}
+ Queued: {{ default_proompters_in_queue }}
+
Client API URL: {{ client_api }}
Streaming API URL: {{ ws_client_api if enable_streaming else 'Disabled' }}
@@ -91,17 +108,20 @@
-
-
Instructions:
+
Instructions
+
+ In Settings > Power User Options, enable Relaxed API URLs.
Set your API type to {{ mode_name }}
Enter {{ client_api }} in the {{ api_input_textbox }} textbox.
- {% if enable_streaming %}Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox. {% endif %}
+ {% if enable_streaming %}
+ Enter {{ ws_client_api }} in the {{ streaming_input_textbox }} textbox.
+ {% endif %}
If you have a token, check the Mancer AI checkbox and enter your token in the Mancer
API key textbox.
Click Connect to test the connection.
- Open your preset config and set Context Size to {{ context_size }}.
+ Open your preset config and set Context Size to {{ default_context_size }}.
Follow this guide to get set up: rentry.org/freellamas
@@ -120,13 +140,45 @@
-
{{ stats_json|safe }}
+
Statistics
+ Proompters:
+
+ 5 minutes: {{ proompters_5_min }}
+ 24 hours: {{ proompters_24_hrs }}
+
+
+
+ {% for key, value in model_choices.items() %}
+
+
{{ key }}
+
+ {% if value.queued == 0 and value.processing >= value.concurrent_gens %}
+ {# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
+ {% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
+ {% else %}
+ {% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
+ {% endif %}
+
+
+ Estimated Wait Time: {{ estimated_wait_sec }}
+ Processing: {{ value.processing }}
+ Queued: {{ value.queued }}
+
+
+ Client API URL: {{ value.client_api }}
+ Streaming API URL: {{ value.ws_client_api }}
+ OpenAI-Compatible API URL: {{ value.openai_client_api }}
+
+
Context Size: {{ value.context_size }}
+
Average Generation Time: {{ value.avg_generation_time | int }} seconds
+
+
+ {% endfor %}
-