Merge cluster to master (#3)

Co-authored-by: Cyberes <cyberes@evulid.cc>
Reviewed-on: #3
Cyberes 2023-10-27 19:19:22 -06:00
parent 561820fb9e
commit 0059e7956c
80 changed files with 2911 additions and 1258 deletions


@@ -43,7 +43,9 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
 
 ### Use
 
+If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
+
+Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
 
 ### To Do


@@ -1,22 +1,19 @@
-import time
-from llm_server.routes.cache import redis
+import argparse
+import logging
+
+try:
+    import gevent.monkey
+    gevent.monkey.patch_all()
+except ImportError:
+    pass
+
 import os
 import sys
+import time
 from pathlib import Path
-from llm_server.config.load import load_config
-from llm_server.database.create import create_db
-from llm_server.workers.app import start_background
+
+from redis import Redis
+
+from llm_server.cluster.cluster_config import cluster_config
+from llm_server.config.load import load_config, parse_backends
+from llm_server.custom_redis import redis
+from llm_server.database.create import create_db
+from llm_server.logging import create_logger, logging_info, init_logging
+from llm_server.routes.v1.generate_stats import generate_stats
+from llm_server.workers.threader import start_background
 
 script_path = os.path.dirname(os.path.realpath(__file__))
 config_path_environ = os.getenv("CONFIG_PATH")
@@ -26,19 +23,46 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')
 
 if __name__ == "__main__":
-    flushed_keys = redis.flush()
-    print('Flushed', len(flushed_keys), 'keys from Redis.')
-
-    success, config, msg = load_config(config_path, script_path)
+    parser = argparse.ArgumentParser(description='Daemon microservice.')
+    parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
+    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
+    args = parser.parse_args()
+
+    # TODO: have this be set by either the arg or a config value
+    if args.debug:
+        logging_info.level = logging.DEBUG
+
+    init_logging()
+    logger = create_logger('daemon')
+    logger.debug('Debug logging enabled.')
+
+    if not args.no_reset:
+        Redis().flushall()
+        logger.info('Flushed Redis.')
+
+    success, config, msg = load_config(config_path)
     if not success:
-        print('Failed to load config:', msg)
+        logger.info(f'Failed to load config: {msg}')
         sys.exit(1)
     create_db()
+
+    cluster_config.clear()
+    cluster_config.load(parse_backends(config))
+
+    logger.info('Loading backend stats...')
+    generate_stats(regen=True)
+
     start_background()
-    redis.set('daemon_started', 1)
-    print('== Daemon Setup Complete ==\n')
-    while True:
-        time.sleep(3600)
+
+    # Give some time for the background threads to get themselves ready to go.
+    time.sleep(2)
+    redis.set('daemon_started', 1)
+    logger.info('== Daemon Setup Complete ==')
+
+    try:
+        while True:
+            time.sleep(3600)
+    except KeyboardInterrupt:
+        redis.set('daemon_started', 0)



@@ -0,0 +1,117 @@
import numpy as np

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.llm.info import get_info
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model


def get_backends_from_model(model_name: str):
    return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]


def get_running_models():
    return redis_running_models.keys()


def purge_backend_from_running_models(backend_url: str):
    keys = redis_running_models.keys()
    pipeline = redis_running_models.pipeline()
    for model in keys:
        pipeline.srem(model, backend_url)
    pipeline.execute()


def is_valid_model(model_name: str):
    return redis_running_models.exists(model_name)


def test_backend(backend_url: str, test_prompt: bool = False):
    backend_info = cluster_config.get_backend(backend_url)
    if test_prompt:
        handler = VLLMBackend(backend_url)
        parameters, _ = handler.get_parameters({
            "stream": False,
            "temperature": 0,
            "max_new_tokens": 3,
        })
        data = {
            'prompt': 'test prompt',
            **parameters
        }
        try:
            success, response, err = generator(data, backend_url, timeout=10)
            if not success or not response or err:
                return False, {}
        except:
            return False, {}
    i = get_info(backend_url, backend_info['mode'])
    if not i.get('model'):
        return False, {}
    return True, i


def get_model_choices(regen: bool = False):
    if not regen:
        c = redis.getp('model_choices')
        if c:
            return c

    base_client_api = redis.get('base_client_api', dtype=str)
    running_models = get_running_models()
    model_choices = {}
    for model in running_models:
        b = get_backends_from_model(model)

        context_size = []
        avg_gen_per_worker = []
        concurrent_gens = 0
        for backend_url in b:
            backend_info = cluster_config.get_backend(backend_url)
            if backend_info.get('model_config'):
                context_size.append(backend_info['model_config']['max_position_embeddings'])
            if backend_info.get('average_generation_elapsed_sec'):
                avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
            concurrent_gens += backend_info['concurrent_gens']

        active_gen_workers = get_active_gen_workers_model(model)
        proompters_in_queue = priority_queue.len(model)

        if len(avg_gen_per_worker):
            average_generation_elapsed_sec = np.average(avg_gen_per_worker)
        else:
            average_generation_elapsed_sec = 0
        estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, concurrent_gens, active_gen_workers)

        model_choices[model] = {
            'model': model,
            'client_api': f'https://{base_client_api}/{model}',
            'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
            'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
            'backend_count': len(b),
            'estimated_wait': estimated_wait_sec,
            'queued': proompters_in_queue,
            'processing': active_gen_workers,
            'avg_generation_time': average_generation_elapsed_sec,
            'concurrent_gens': concurrent_gens
        }
        if len(context_size):
            model_choices[model]['context_size'] = min(context_size)

    # Python wants to sort lowercase vs. uppercase letters differently.
    model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))

    default_backend_url = get_a_cluster_backend()
    default_backend_info = cluster_config.get_backend(default_backend_url)
    if not default_backend_info.get('model'):
        return {}, None
    default_model = default_backend_info['model']

    redis.setp('model_choices', (model_choices, default_model))
    return model_choices, default_model


@@ -0,0 +1,124 @@
import hashlib
import pickle
import traceback

from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import RedisCustom
from llm_server.routes.helpers.model import estimate_model_size


class RedisClusterStore:
    def __init__(self, name: str, **kwargs):
        self.name = name
        self.config_redis = RedisCustom(name, **kwargs)

    def clear(self):
        self.config_redis.flush()

    def load(self, config: dict):
        for k, v in config.items():
            self.add_backend(k, v)

    def add_backend(self, name: str, values: dict):
        self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
        self.set_backend_value(name, 'online', False)
        h = hashlib.sha256(name.encode('utf-8')).hexdigest()
        self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')

    def set_backend_value(self, backend: str, key: str, value):
        # By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
        self.config_redis.hset(backend, key, pickle.dumps(value))

    def get_backend(self, name: str):
        r = self.config_redis.hgetall(name)
        output = {}
        for k, v in r.items():
            output[k.decode('utf8')] = pickle.loads(v)
        return output

    def all(self):
        keys = self.config_redis.keys('*')
        if keys:
            result = {}
            for key in keys:
                if key != f'{self.name}:____':
                    v = self.get_backend(key)
                    result[key] = v
            return result
        else:
            return {}

    def validate_backend(self, backend_url: str):
        """
        Returns the backend URL that was given, or a new one if that was offline.
        :param backend_url:
        :return:
        """
        backend_info = self.get_backend(backend_url)
        if not backend_info['online']:
            old = backend_url
            backend_url = get_a_cluster_backend()
            print(f'Backend {old} offline. Request was redirected to {backend_url}')
        return backend_url


cluster_config = RedisClusterStore('cluster_config')


def get_backends():
    backends = cluster_config.all()
    result = {}
    for k, v in backends.items():
        b = cluster_config.get_backend(k)
        status = b.get('online', False)
        priority = b['priority']
        result[k] = {'status': status, 'priority': priority}
    try:
        if not opts.prioritize_by_size:
            online_backends = sorted(
                ((url, info) for url, info in backends.items() if info['online']),
                key=lambda kv: -kv[1]['priority'],
                reverse=True
            )
        else:
            online_backends = sorted(
                ((url, info) for url, info in backends.items() if info['online']),
                key=lambda kv: estimate_model_size(kv[1]['model_config']),
                reverse=True
            )
        offline_backends = sorted(
            ((url, info) for url, info in backends.items() if not info['online']),
            key=lambda kv: -kv[1]['priority'],
            reverse=True
        )
        return [url for url, info in online_backends], [url for url, info in offline_backends]
    except KeyError:
        traceback.print_exc()
        print(backends)


def get_a_cluster_backend(model=None):
    """
    Get a backend from Redis. If there are no online backends, return None.
    If `model` is not supplied, we will pick one ourself.
    """
    if model:
        # First, determine if there are multiple backends hosting the same model.
        backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]

        # If so, create an iterator for those backends
        if len(backends_hosting_model):
            add_backend_cycler(model, backends_hosting_model)
            cycled = redis_cycle(model)
            if len(cycled):
                return cycled[0]
        else:
            # No backend hosting that model
            return None
    else:
        online, _ = get_backends()
        if len(online):
            return online[0]

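For readers skimming the diff, a minimal usage sketch of the new cluster store may help. The backend URL and field values below are invented; in the real code they come from `parse_backends()` and the health-check worker shown later in this commit.

```python
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend

# Hypothetical backend entry; the real fields come from the `cluster` config section.
url = 'http://10.0.0.5:7000'
cluster_config.add_backend(url, {
    'backend_url': url,
    'mode': 'vllm',
    'priority': 100,
    'concurrent_gens': 3,
})

# add_backend() marks every backend offline until the health check flips it.
cluster_config.set_backend_value(url, 'online', True)

info = cluster_config.get_backend(url)  # fields are unpickled back into Python objects
print(info['mode'], info['online'], info['hash'])

print(get_a_cluster_backend())          # an online backend URL, or None if none are up
```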

@@ -0,0 +1,39 @@
import redis

redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)


def redis_cycle(list_name):
    """
    Emulates itertools.cycle() but returns the complete shuffled list.
    :param list_name:
    :return:
    """
    pipeline = redis_cycler_db.pipeline()
    pipeline.lpop(list_name)
    to_move = pipeline.execute()[0]
    if not to_move:
        return []
    pipeline.rpush(list_name, to_move)
    pipeline.lrange(list_name, 0, -1)
    results = pipeline.execute()
    new_list = results[-1]
    return [x.decode('utf-8') for x in new_list]


def add_backend_cycler(list_name: str, new_elements: list):
    existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
    existing_set = set(existing_elements)

    with redis_cycler_db.pipeline() as pipe:
        # Add elements
        for element in new_elements:
            if element not in existing_set:
                pipe.rpush(list_name, element)

        # Remove elements
        for element in existing_set:
            if element not in new_elements:
                pipe.lrem(list_name, 0, element)

        pipe.execute()

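A rough sketch of how the cycler behaves. The model name and backend URLs are hypothetical, and it needs a local Redis reachable on the hard-coded db 9 above.

```python
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle

# Seed (or reconcile) the rotation list for a model with the backends currently hosting it.
add_backend_cycler('example-model', ['http://10.0.0.5:7000', 'http://10.0.0.6:7000'])

# Each call pops the head of the list, appends it to the tail, and returns the rotated
# list, so cycled[0] walks round-robin across the backends on successive calls.
for _ in range(4):
    cycled = redis_cycle('example-model')
    print(cycled[0])
```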

@@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom

redis_running_models = RedisCustom('running_models')


@@ -0,0 +1,38 @@
import time
from threading import Thread

from llm_server.cluster.backend import test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models


def cluster_worker():
    counter = 0
    while True:
        test_prompt = False
        if counter % 4 == 0:
            # Only send a test prompt every 120 seconds.
            test_prompt = True
        threads = []
        for n, v in cluster_config.all().items():
            thread = Thread(target=check_backend, args=(n, v, test_prompt))
            thread.start()
            threads.append(thread)
        for thread in threads:
            thread.join()
        time.sleep(15)
        counter += 1


def check_backend(n, v, test_prompt):
    online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
    if online:
        running_model = backend_info['model']
        for k, v in backend_info.items():
            cluster_config.set_backend_value(n, k, v)
        redis_running_models.sadd(running_model, n)
    else:
        for model in redis_running_models.keys():
            redis_running_models.srem(model, n)
    cluster_config.set_backend_value(n, 'online', online)


@@ -28,16 +28,19 @@ config_default_vars = {
     'openai_force_no_hashes': True,
     'include_system_tokens_in_stats': True,
     'openai_moderation_scan_last_n': 5,
-    'openai_moderation_workers': 10,
     'openai_org_name': 'OpenAI',
     'openai_silent_trim': False,
     'openai_moderation_enabled': True,
-    'netdata_root': None
+    'netdata_root': None,
+    'show_backends': True,
+    'background_homepage_cacher': True,
+    'openai_moderation_timeout': 5,
+    'prioritize_by_size': False
 }
 
-config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
+config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
 
 mode_ui_names = {
-    'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
+    'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
     'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
 }


@@ -3,38 +3,28 @@ import sys
 
 import openai
 
+import llm_server
 from llm_server import opts
 from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
+from llm_server.custom_redis import redis
 from llm_server.database.conn import database
 from llm_server.database.database import get_number_of_rows
-from llm_server.helpers import resolve_path
-from llm_server.routes.cache import redis
+from llm_server.routes.queue import PriorityQueue
 
 
-def load_config(config_path, script_path):
+def load_config(config_path):
     config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
     success, config, msg = config_loader.load_config()
     if not success:
         return success, config, msg
 
-    # Resolve relative directory to the directory of the script
-    if config['database_path'].startswith('./'):
-        config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
-
-    if config['mode'] not in ['oobabooga', 'vllm']:
-        print('Unknown mode:', config['mode'])
-        sys.exit(1)
-
     # TODO: this is atrocious
-    opts.mode = config['mode']
     opts.auth_required = config['auth_required']
     opts.log_prompts = config['log_prompts']
-    opts.concurrent_gens = config['concurrent_gens']
     opts.frontend_api_client = config['frontend_api_client']
-    opts.context_size = config['token_limit']
     opts.show_num_prompts = config['show_num_prompts']
     opts.show_uptime = config['show_uptime']
-    opts.backend_url = config['backend_url'].strip('/')
+    opts.cluster = config['cluster']
     opts.show_total_output_tokens = config['show_total_output_tokens']
     opts.netdata_root = config['netdata_root']
     opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@@ -53,10 +43,20 @@ def load_config(config_path, script_path):
     opts.openai_force_no_hashes = config['openai_force_no_hashes']
     opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
     opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
-    opts.openai_moderation_workers = config['openai_moderation_workers']
     opts.openai_org_name = config['openai_org_name']
     opts.openai_silent_trim = config['openai_silent_trim']
     opts.openai_moderation_enabled = config['openai_moderation_enabled']
+    opts.show_backends = config['show_backends']
+    opts.background_homepage_cacher = config['background_homepage_cacher']
+    opts.openai_moderation_timeout = config['openai_moderation_timeout']
+    opts.frontend_api_mode = config['frontend_api_mode']
+    opts.prioritize_by_size = config['prioritize_by_size']
+
+    # Scale the number of workers.
+    for item in config['cluster']:
+        opts.cluster_workers += item['concurrent_gens']
+
+    llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
 
     if opts.openai_expose_our_model and not opts.openai_api_key:
         print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
@@ -78,6 +78,16 @@ def load_config(config_path, script_path):
     if config['load_num_prompts']:
         redis.set('proompts', get_number_of_rows('prompts'))
 
-    redis.set('backend_mode', opts.mode)
-
     return success, config, msg
+
+
+def parse_backends(config):
+    if not config.get('cluster'):
+        return False
+    cluster = config.get('cluster')
+    config = {}
+    for item in cluster:
+        backend_url = item['backend_url'].strip('/')
+        item['backend_url'] = backend_url
+        config[backend_url] = item
+    return config

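To illustrate the shape `parse_backends()` produces, here is a small sketch. The `cluster` section below is made up; the real one is loaded from config.yml.

```python
from llm_server.config.load import parse_backends

# A made-up `cluster` section for illustration only.
config = {
    'cluster': [
        {'backend_url': 'http://10.0.0.5:7000/', 'mode': 'vllm', 'priority': 100, 'concurrent_gens': 3},
        {'backend_url': 'http://10.0.0.6:7000', 'mode': 'vllm', 'priority': 50, 'concurrent_gens': 1},
    ]
}

backends = parse_backends(config)
# Trailing slashes are stripped and each entry is re-keyed by its backend_url, which is
# the shape RedisClusterStore.load() expects:
# {'http://10.0.0.5:7000': {...}, 'http://10.0.0.6:7000': {...}}
print(list(backends.keys()))
```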

@@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom

redis_config = RedisCustom('redis_config')


@@ -1,24 +1,27 @@
+import pickle
 import sys
 import traceback
-from typing import Callable, List, Mapping, Union
+from typing import Callable, List, Mapping, Optional, Union
 
 import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
+from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
 
-flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
+flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 
 ONE_MONTH_SECONDS = 2678000
 
 
-class RedisWrapper:
+class RedisCustom(Redis):
     """
-    A wrapper class to set prefixes to keys.
+    A simple wrapper class for Redis to create a "namespace" within a DB,
+    which simplyifies key management.
     """
 
     def __init__(self, prefix, **kwargs):
+        super().__init__()
         self.redis = Redis(**kwargs)
         self.prefix = prefix
         try:
@@ -34,12 +37,11 @@
     def set(self, key, value, ex: Union[ExpiryT, None] = None):
         return self.redis.set(self._key(key), value, ex=ex)
 
-    def get(self, key, dtype=None, default=None):
-        """
-        :param key:
-        :param dtype: convert to this type
-        :return:
-        """
+    def get(self, key, default=None, dtype=None):
+        # TODO: use pickle
+        import inspect
+        if inspect.isclass(default):
+            raise Exception
         d = self.redis.get(self._key(key))
         if dtype and d:
@@ -108,7 +110,10 @@
     ):
         return self.redis.hincrby(self._key(name), key, amount)
 
-    def hdel(self, name: str, *keys: List):
+    def zcard(self, name: KeyT):
+        return self.redis.zcard(self._key(name))
+
+    def hdel(self, name: str, *keys: str):
         return self.redis.hdel(self._key(name), *keys)
 
     def hget(
@@ -129,9 +134,62 @@
     ):
         return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
 
+    def lpush(self, name: str, *values: FieldT):
+        return self.redis.lpush(self._key(name), *values)
+
+    def hset(
+            self,
+            name: str,
+            key: Optional = None,
+            value=None,
+            mapping: Optional[dict] = None,
+            items: Optional[list] = None,
+    ):
+        return self.redis.hset(self._key(name), key, value, mapping, items)
+
     def hkeys(self, name: str):
         return self.redis.hkeys(self._key(name))
 
+    def hmget(self, name: str, keys: List, *args: List):
+        return self.redis.hmget(self._key(name), keys, *args)
+
+    def hgetall(self, name: str):
+        return self.redis.hgetall(self._key(name))
+
+    def keys(self, pattern: PatternT = "*", **kwargs):
+        raw_keys = self.redis.keys(self._key(pattern), **kwargs)
+        keys = []
+        for key in raw_keys:
+            p = key.decode('utf-8').split(':')
+            if len(p) >= 2:
+                # Delete prefix
+                del p[0]
+                k = ':'.join(p)
+                if k != '____':
+                    keys.append(k)
+        return keys
+
+    def pipeline(self, transaction=True, shard_hint=None):
+        return self.redis.pipeline(transaction, shard_hint)
+
+    def smembers(self, name: str):
+        return self.redis.smembers(self._key(name))
+
+    def spop(self, name: str, count: Optional[int] = None):
+        return self.redis.spop(self._key(name), count)
+
+    def rpoplpush(self, src, dst):
+        return self.redis.rpoplpush(src, dst)
+
+    def zpopmin(self, name: KeyT, count: Union[int, None] = None):
+        return self.redis.zpopmin(self._key(name), count)
+
+    def exists(self, *names: KeyT):
+        n = []
+        for name in names:
+            n.append(self._key(name))
+        return self.redis.exists(*n)
+
     def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
         return self.set(key, json.dumps(dict_value), ex=ex)
@@ -142,6 +200,15 @@
         else:
             return json.loads(r.decode("utf-8"))
 
+    def setp(self, name, value):
+        self.redis.set(self._key(name), pickle.dumps(value))
+
+    def getp(self, name: str):
+        r = self.redis.get(self._key(name))
+        if r:
+            return pickle.loads(r)
+        return r
+
     def flush(self):
         flushed = []
         for key in self.redis.scan_iter(f'{self.prefix}:*'):
@@ -149,5 +216,40 @@
             self.redis.delete(key)
         return flushed
 
+    def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
+        self.flush()
+        return True
+
+    def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
+        self.flush()
+        return True
+
+    def lrange(self, name: str, start: int, end: int):
+        return self.redis.lrange(self._key(name), start, end)
+
+    def delete(self, *names: KeyT):
+        return self.redis.delete(*[self._key(i) for i in names])
+
+    def lpop(self, name: str, count: Optional[int] = None):
+        return self.redis.lpop(self._key(name), count)
+
+    def zrange(
+            self,
+            name: KeyT,
+            start: int,
+            end: int,
+            desc: bool = False,
+            withscores: bool = False,
+            score_cast_func: Union[type, Callable] = float,
+            byscore: bool = False,
+            bylex: bool = False,
+            offset: int = None,
+            num: int = None,
+    ):
+        return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
+
+    def zrem(self, name: KeyT, *values: FieldT):
+        return self.redis.zrem(self._key(name), *values)
 
-redis = RedisWrapper('local_llm')
+redis = RedisCustom('local_llm')

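A minimal sketch of what the renamed `RedisCustom` namespace wrapper gives callers, based only on the methods visible in this hunk. The instance names and keys below are invented.

```python
from llm_server.custom_redis import RedisCustom

# Two instances in the same Redis DB, isolated by their key prefixes.
frontend = RedisCustom('frontend')
workers = RedisCustom('workers')

frontend.set('foo', 1)
workers.set('foo', 2)        # stored as 'workers:foo', does not clobber 'frontend:foo'

# setp()/getp() pickle the value, so arbitrary Python objects round-trip without casting.
frontend.setp('model_choices', ({'example-model': {'backend_count': 2}}, 'example-model'))
print(frontend.getp('model_choices'))

print(frontend.keys())       # prefixes are stripped from the returned key names
frontend.flush()             # deletes only this namespace's keys, not the whole DB
```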

@@ -5,20 +5,20 @@ class DatabaseConnection:
     host: str = None
     username: str = None
     password: str = None
-    database: str = None
+    database_name: str = None
 
-    def init_db(self, host, username, password, database):
+    def init_db(self, host, username, password, database_name):
         self.host = host
         self.username = username
         self.password = password
-        self.database = database
+        self.database_name = database_name
 
     def cursor(self):
         db = pymysql.connect(
             host=self.host,
             user=self.username,
             password=self.password,
-            database=self.database,
+            database=self.database_name,
             charset='utf8mb4',
             autocommit=True,
         )


@@ -1,15 +1,19 @@
 import json
 import time
 import traceback
+from typing import Union
 
-import llm_server
 from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.database.conn import database
-from llm_server.llm.vllm import tokenize
-from llm_server.routes.cache import redis
+from llm_server.llm import get_token_count
 
 
-def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
+def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+    assert isinstance(prompt, str)
+    assert isinstance(backend_url, str)
+
+    # Try not to shove JSON into the database.
     if isinstance(response, dict) and response.get('results'):
         response = response['results'][0]['text']
     try:
@@ -19,10 +23,11 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     except:
         pass
 
-    prompt_tokens = llm_server.llm.get_token_count(prompt)
+    prompt_tokens = get_token_count(prompt, backend_url)
+
     if not is_error:
         if not response_tokens:
-            response_tokens = llm_server.llm.get_token_count(response)
+            response_tokens = get_token_count(response, backend_url)
     else:
         response_tokens = None
@@ -43,7 +48,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
     if token:
         increment_token_uses(token)
 
-    running_model = redis.get('running_model', str, 'ERROR')
+    backend_info = cluster_config.get_backend(backend_url)
+    running_model = backend_info.get('model')
+    backend_mode = backend_info['mode']
     timestamp = int(time.time())
     cursor = database.cursor()
     try:
@@ -52,7 +59,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
                 (ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
                 VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                 """,
-                       (ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
+                       (ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
     finally:
         cursor.close()
@@ -179,3 +186,21 @@ def increment_token_uses(token):
         cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
     finally:
         cursor.close()
+
+
+def get_token_ratelimit(token):
+    priority = 9990
+    simultaneous_ip = opts.simultaneous_requests_per_ip
+    if token:
+        cursor = database.cursor()
+        try:
+            cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
+            result = cursor.fetchone()
+            if result:
+                priority, simultaneous_ip = result
+                if simultaneous_ip is None:
+                    # No ratelimit for this token if null
+                    simultaneous_ip = 999999999
+        finally:
+            cursor.close()
+    return priority, simultaneous_ip

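A short sketch of how the new `get_token_ratelimit()` helper is meant to be consumed; the token string here is a placeholder.

```python
from llm_server.database.database import get_token_ratelimit

# No token: fall back to the default priority (9990) and the global
# simultaneous_requests_per_ip limit from the config.
priority, simultaneous_ip = get_token_ratelimit(None)

# A token found in `token_auth` gets its own priority and per-IP limit; a NULL
# simultaneous_ip in that table effectively disables the limit (999999999).
priority, simultaneous_ip = get_token_ratelimit('example-token')
print(priority, simultaneous_ip)
```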

@@ -0,0 +1,30 @@
import pickle
from typing import Union

from redis import Redis


def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
    assert isinstance(prompt, str)
    assert isinstance(backend_url, str)

    r = Redis(host='localhost', port=6379, db=3)
    data = {
        'function': 'log_prompt',
        'args': [],
        'kwargs': {
            'ip': ip,
            'token': token,
            'prompt': prompt,
            'response': response,
            'gen_time': gen_time,
            'parameters': parameters,
            'headers': dict(headers) if headers else headers,
            'backend_response_code': backend_response_code,
            'request_url': request_url,
            'backend_url': backend_url,
            'response_tokens': response_tokens,
            'is_error': is_error
        }
    }
    r.publish('database-logger', pickle.dumps(data))

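`log_to_db()` only serializes the call and publishes it to the `database-logger` channel; the consumer that actually writes to MySQL lives in the worker package and is not part of this hunk. A minimal sketch of what that other end could look like, assuming it simply unpickles the payload and forwards it to `do_db_log()`:

```python
import pickle

from redis import Redis

from llm_server.database.database import do_db_log

r = Redis(host='localhost', port=6379, db=3)
pubsub = r.pubsub()
pubsub.subscribe('database-logger')

for message in pubsub.listen():
    if message['type'] != 'message':
        continue
    data = pickle.loads(message['data'])
    if data.get('function') == 'log_prompt':
        # Everything was sent as kwargs, so it maps straight onto do_db_log().
        do_db_log(*data['args'], **data['kwargs'])
```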

@@ -8,7 +8,7 @@ import simplejson as json
 from flask import make_response
 
 from llm_server import opts
-from llm_server.routes.cache import redis
+from llm_server.custom_redis import redis
 
 
 def resolve_path(*p: str):
@@ -54,13 +54,14 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys
 def round_up_base(n, base):
     if base == 0:
-        print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
+        # TODO: I don't think passing (0, 0) to this function is a sign of any underlying issues.
+        # print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
         return 0
     return math.ceil(n / base) * base
 
 
 def auto_set_base_client_api(request):
-    http_host = redis.get('http_host', str)
+    http_host = redis.get('http_host', dtype=str)
     host = request.headers.get("Host")
     if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
         # If the current http_host is not an IP, don't do anything.


@@ -1,12 +0,0 @@
import threading


class ThreadSafeInteger:
    def __init__(self, value=0):
        self.value = value
        self._value_lock = threading.Lock()

    def increment(self):
        with self._value_lock:
            self.value += 1
            return self.value


@@ -1,11 +1,30 @@
+import tiktoken
+
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.llm import oobabooga, vllm
-from llm_server.routes.cache import redis
+from llm_server.logging import create_logger
 
 
-def get_token_count(prompt: str):
-    backend_mode = redis.get('backend_mode', str)
+def fallback_tokenizer(prompt: str):
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+    return len(tokenizer.encode(prompt)) + 10
+
+
+def get_token_count(prompt: str, backend_url: str):
+    backend_url = cluster_config.validate_backend(backend_url)
+    if not backend_url:
+        logger = create_logger('tokenizer')
+        logger.warning('using fallback tokenizer as there is no valid backend')
+        return fallback_tokenizer(prompt)
+
+    backend_mode = cluster_config.get_backend(backend_url).get('mode')
+    if not backend_mode:
+        logger = create_logger('tokenizer')
+        logger.warning("using fallback tokenizer as the backend isn't initalized")
+        return fallback_tokenizer(prompt)
+
     if backend_mode == 'vllm':
-        return vllm.tokenize(prompt)
+        return vllm.tokenize(prompt, backend_url)
     elif backend_mode == 'ooba':
         return oobabooga.tokenize(prompt)
     else:


@@ -1,14 +1,15 @@
 from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
 
 
-def generator(request_json_body):
-    if opts.mode == 'oobabooga':
+def generator(request_json_body, cluster_backend, timeout: int = None):
+    mode = cluster_config.get_backend(cluster_backend)['mode']
+    if mode == 'ooba':
         # from .oobabooga.generate import generate
         # return generate(request_json_body)
         raise NotImplementedError
-    elif opts.mode == 'vllm':
+    elif mode == 'vllm':
         from .vllm.generate import generate
-        r = generate(request_json_body)
-        return r
+        return generate(request_json_body, cluster_backend, timeout=timeout)
     else:
         raise Exception


@@ -3,23 +3,35 @@ import requests
 from llm_server import opts
 
 
-def get_running_model():
-    # TODO: cache the results for 1 min so we don't have to keep calling the backend
-    # TODO: only use one try/catch
-    if opts.mode == 'oobabooga':
+def get_running_model(backend_url: str, mode: str):
+    if mode == 'ooba':
         try:
-            backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
             r_json = backend_response.json()
             return r_json['result'], None
         except Exception as e:
             return False, e
-    elif opts.mode == 'vllm':
+    elif mode == 'vllm':
         try:
-            backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
+            backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
             r_json = backend_response.json()
             return r_json['model'], None
         except Exception as e:
             return False, e
     else:
         raise Exception
+
+
+def get_info(backend_url: str, mode: str):
+    if mode == 'ooba':
+        return {}
+        # raise NotImplementedError
+    elif mode == 'vllm':
+        try:
+            r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
+            j = r.json()
+        except Exception as e:
+            return {}
+        return j
+    else:
+        raise Exception


@@ -2,14 +2,17 @@ from typing import Tuple, Union
 
 import flask
 
-from llm_server import opts
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.llm import get_token_count
-from llm_server.routes.cache import redis
 
 
 class LLMBackend:
     _default_params: dict
 
+    def __init__(self, backend_url: str):
+        self.backend_url = backend_url
+        self.backend_info = cluster_config.get_backend(self.backend_url)
+
     def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
         raise NotImplementedError
@@ -32,14 +35,16 @@
         """
         If a backend needs to do other checks not related to the prompt or parameters.
         Default is no extra checks preformed.
-        :param request:
-        :param prompt:
         :param parameters:
         :return:
         """
         return True, None
 
     def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
-        prompt_len = get_token_count(prompt)
-        if prompt_len > opts.context_size - 10:
-            model_name = redis.get('running_model', str, 'NO MODEL ERROR')
-            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
+        prompt_len = get_token_count(prompt, self.backend_url)
+        token_limit = self.backend_info['model_config']['max_position_embeddings']
+        if prompt_len > token_limit - 10:
+            return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
         return True, None


@@ -1,78 +1,6 @@
-from flask import jsonify
-
 from ..llm_backend import LLMBackend
-from ...database.database import log_prompt
-from ...helpers import safe_list_get
-from ...routes.cache import redis
-from ...routes.helpers.client import format_sillytavern_err
-from ...routes.helpers.http import validate_json
 
 
 class OobaboogaBackend(LLMBackend):
-    default_params = {}
+    def __int__(self):
+        return
 
-    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
-        raise NotImplementedError('need to implement default_params')
-        backend_err = False
-        response_valid_json, response_json_body = validate_json(response)
-        if response:
-            try:
-                # Be extra careful when getting attributes from the response object
-                response_status_code = response.status_code
-            except:
-                response_status_code = 0
-        else:
-            response_status_code = None
-
-        # ===============================================
-        # We encountered an error
-        if not success or not response or error_msg:
-            if not error_msg or error_msg == '':
-                error_msg = 'Unknown error.'
-            else:
-                error_msg = error_msg.strip('.') + '.'
-            backend_response = format_sillytavern_err(error_msg, 'error')
-            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': error_msg,
-                'results': [{'text': backend_response}]
-            }), 400
-
-        # ===============================================
-
-        if response_valid_json:
-            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
-            if not backend_response:
-                # Ooba doesn't return any error messages so we will just tell the client an error occurred
-                backend_err = True
-                backend_response = format_sillytavern_err(
-                    f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
-                    'error')
-                response_json_body['results'][0]['text'] = backend_response
-
-            if not backend_err:
-                redis.incr('proompts')
-
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
-            return jsonify({
-                **response_json_body
-            }), 200
-        else:
-            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
-            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': 'the backend did not return valid JSON',
-                'results': [{'text': backend_response}]
-            }), 400
-
-    def validate_params(self, params_dict: dict):
-        # No validation required
-        return True, None
-
-    def get_parameters(self, parameters):
-        del parameters['prompt']
-        return parameters


@@ -10,7 +10,7 @@ def check_moderation_endpoint(prompt: str):
     }
     response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
     if response.status_code != 200:
-        print(response.text)
+        print('moderation failed:', response)
         response.raise_for_status()
     response = response.json()


@@ -0,0 +1,97 @@
from flask import jsonify

from llm_server import opts


def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
    if not request_json_body.get('stop'):
        request_json_body['stop'] = []
    if not isinstance(request_json_body['stop'], list):
        # It is a string, so create a list with the existing element.
        request_json_body['stop'] = [request_json_body['stop']]

    if stop_hashes:
        if opts.openai_force_no_hashes:
            request_json_body['stop'].append('###')
        else:
            # TODO: make stopping strings a configurable
            request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
    else:
        request_json_body['stop'].extend(['user:', 'assistant:'])

    if request_json_body.get('frequency_penalty', 0) < -2:
        request_json_body['frequency_penalty'] = -2
    elif request_json_body.get('frequency_penalty', 0) > 2:
        request_json_body['frequency_penalty'] = 2

    if mode == 'vllm' and request_json_body.get('top_p') == 0:
        request_json_body['top_p'] = 0.01

    request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
    if request_json_body['max_tokens'] == 0:
        # We don't want to set any defaults here.
        del request_json_body['max_tokens']

    return request_json_body


def format_oai_err(err_msg):
    print('OAI ERROR MESSAGE:', err_msg)
    return jsonify({
        "error": {
            "message": err_msg,
            "type": "invalid_request_error",
            "param": None,
            "code": None
        }
    }), 400


def validate_oai(parameters):
    if parameters.get('messages'):
        for m in parameters['messages']:
            if m['role'].lower() not in ['assistant', 'user', 'system']:
                return format_oai_err('messages role must be assistant, user, or system')

    if parameters.get('temperature', 0) > 2:
        return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
    if parameters.get('temperature', 0) < 0:
        return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'")

    if parameters.get('top_p', 1) > 2:
        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
    if parameters.get('top_p', 1) < 0:
        return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")

    if parameters.get('presence_penalty', 1) > 2:
        return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
    if parameters.get('presence_penalty', 1) < -2:
        return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'")

    if parameters.get('top_p', 1) > 2:
        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
    if parameters.get('top_p', 1) < 0:
        return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")

    if parameters.get('top_p', 1) > 2:
        return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
    if parameters.get('top_p', 1) < 0:
        return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")

    if parameters.get('max_tokens', 2) < 1:
        return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")


def return_invalid_model_err(requested_model: str):
    if requested_model:
        msg = f"The model `{requested_model}` does not exist"
    else:
        msg = "The requested model does not exist"

    return jsonify({
        "error": {
            "message": msg,
            "type": "invalid_request_error",
            "param": None,
            "code": "model_not_found"
        }
    }), 404

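A hedged sketch of what `oai_to_vllm()` does to an OpenAI-style body. The module path is an assumption (the diff does not show the file name), and the request values are made up.

```python
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm  # module path assumed

# A hypothetical OpenAI-style request body.
body = {
    'stop': '###',
    'frequency_penalty': 5,
    'top_p': 0,
    'max_tokens': 99999,
}

cleaned = oai_to_vllm(body, stop_hashes=True, mode='vllm')
# `stop` becomes a list (plus '###' when openai_force_no_hashes is on), frequency_penalty
# is clamped to [-2, 2], a zero top_p is bumped to 0.01 for vLLM, and max_tokens is
# capped at opts.max_new_tokens.
print(cleaned)
```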

@@ -2,86 +2,35 @@ import concurrent.futures
 import re
 import secrets
 import string
-import time
 import traceback
 from typing import Dict, List
 
 import tiktoken
-from flask import jsonify, make_response
 
-import llm_server
 from llm_server import opts
 from llm_server.llm import get_token_count
-from llm_server.routes.cache import redis
 
 ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
 ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
 
 
-def build_openai_response(prompt, response, model=None):
-    # Seperate the user's prompt from the context
-    x = prompt.split('### USER:')
-    if len(x) > 1:
-        prompt = re.sub(r'\n$', '', x[-1].strip(' '))
-
-    # Make sure the bot doesn't put any other instructions in its response
-    # y = response.split('\n### ')
-    # if len(y) > 1:
-    #     response = re.sub(r'\n$', '', y[0].strip(' '))
-    response = re.sub(ANTI_RESPONSE_RE, '', response)
-    response = re.sub(ANTI_CONTINUATION_RE, '', response)
-
-    # TODO: async/await
-    prompt_tokens = llm_server.llm.get_token_count(prompt)
-    response_tokens = llm_server.llm.get_token_count(response)
-    running_model = redis.get('running_model', str, 'ERROR')
-
-    response = make_response(jsonify({
-        "id": f"chatcmpl-{generate_oai_string(30)}",
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": running_model if opts.openai_expose_our_model else model,
-        "choices": [{
-            "index": 0,
-            "message": {
-                "role": "assistant",
-                "content": response,
-            },
-            "logprobs": None,
-            "finish_reason": "stop"
-        }],
-        "usage": {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": response_tokens,
-            "total_tokens": prompt_tokens + response_tokens
-        }
-    }), 200)
-
-    stats = redis.get('proxy_stats', dict)
-    if stats:
-        response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
-    return response
-
-
 def generate_oai_string(length=24):
     alphabet = string.ascii_letters + string.digits
     return ''.join(secrets.choice(alphabet) for i in range(length))
 
 
-def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
-    tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    def get_token_count_tiktoken_thread(msg):
-        return len(tokenizer.encode(msg["content"]))
+def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
+    def get_token_count_thread(msg):
+        return get_token_count(msg["content"], backend_url)
 
     with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
-        token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
+        token_counts = list(executor.map(get_token_count_thread, prompt))
 
     total_tokens = sum(token_counts)
-    formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
+    formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
     # If total tokens exceed the limit, start trimming
-    if total_tokens > context_token_limit:
+    if total_tokens + formatting_tokens > context_token_limit:
         while True:
             while total_tokens + formatting_tokens > context_token_limit:
                 # Calculate the index to start removing messages from
@@ -94,22 +43,43 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
                 if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
                     break
 
-            def get_token_count_thread(msg):
-                return get_token_count(msg["content"])
-
             with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                 token_counts = list(executor.map(get_token_count_thread, prompt))
 
             total_tokens = sum(token_counts)
-            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
+            formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
 
             if total_tokens + formatting_tokens > context_token_limit:
                 # Start over, but this time calculate the token count using the backend
                 with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                     token_counts = list(executor.map(get_token_count_thread, prompt))
             else:
                 break
+
+    return prompt
+
+
+def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+    token_count = get_token_count(prompt, backend_url)
+
+    # If total tokens exceed the limit, start trimming
+    if token_count > context_token_limit:
+        while True:
+            while token_count > context_token_limit:
+                # Calculate the index to start removing characters from
+                remove_index = len(prompt) // 3
+
+                while remove_index < len(prompt):
+                    prompt = prompt[:remove_index] + prompt[remove_index + 100:]
+                    token_count = len(tokenizer.encode(prompt))
+                    if token_count <= context_token_limit or remove_index == len(prompt):
+                        break
+
+            token_count = get_token_count(prompt, backend_url)
+            if token_count > context_token_limit:
+                # Start over, but this time calculate the token count using the backend
+                token_count = get_token_count(prompt, backend_url)
+            else:
+                break
+
     return prompt
@@ -117,8 +87,9 @@ def transform_messages_to_prompt(oai_messages):
     try:
         prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
         for msg in oai_messages:
-            if not msg.get('content') or not msg.get('role'):
+            if 'content' not in msg.keys() or 'role' not in msg.keys():
                 return False
+            msg['content'] = str(msg['content'])  # Prevent any weird issues.
             if msg['role'] == 'system':
                 prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
             elif msg['role'] == 'user':
@@ -126,7 +97,7 @@ def transform_messages_to_prompt(oai_messages):
             elif msg['role'] == 'assistant':
                 prompt += f'### ASSISTANT: {msg["content"]}\n\n'
             else:
-                return False
+                raise Exception(f'Unknown role: {msg["role"]}')
     except Exception as e:
         # TODO: use logging
         traceback.print_exc()

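For reference, a small example of the message-to-prompt conversion these hunks keep reworking. The module path is assumed, and the messages are invented.

```python
from llm_server.llm.openai.transform import transform_messages_to_prompt  # module path assumed

messages = [
    {'role': 'system', 'content': 'You are a helpful assistant.'},
    {'role': 'user', 'content': 'Hello!'},
    {'role': 'assistant', 'content': 'Hi there.'},
]

# Produces the '### INSTRUCTION / ### USER / ### ASSISTANT' prompt layout that
# ANTI_RESPONSE_RE, ANTI_CONTINUATION_RE, and the stop strings in oai_to_vllm assume.
print(transform_messages_to_prompt(messages))
```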

@@ -1,80 +1,21 @@
 """
 This file is used by the worker that processes requests.
 """
-import json
-import time
-from uuid import uuid4
-
 import requests
 
-import llm_server
 from llm_server import opts
-from llm_server.routes.cache import redis
 
 # TODO: make the VLMM backend return TPS and time elapsed
 # https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
 
 
 def prepare_json(json_data: dict):
-    # logit_bias is not currently supported
-    # del json_data['logit_bias']
-
     # Convert back to VLLM.
     json_data['max_tokens'] = json_data.pop('max_new_tokens')
     return json_data
 
 
-def transform_to_text(json_request, api_response):
-    """
-    This is to convert a streaming request to a non-streamed request. Don't think this is nessesary.
-    :param json_request:
-    :param api_response:
-    :return:
-    """
-    prompt = transform_prompt_to_text(json_request['messages'])
-    text = ''
-    finish_reason = None
-    for line in api_response.split('\n'):
-        if line.startswith('data:'):
-            try:
-                data = json.loads(line[5:].strip())
-            except json.decoder.JSONDecodeError:
-                break
-            if 'choices' in data:
-                for choice in data['choices']:
-                    if 'delta' in choice and 'content' in choice['delta']:
-                        text += choice['delta']['content']
-            if data['choices'][0]['finish_reason']:
-                finish_reason = data['choices'][0]['finish_reason']
-
-    prompt_tokens = len(llm_server.llm.get_token_count(prompt))
-    completion_tokens = len(llm_server.llm.get_token_count(text))
-    running_model = redis.get('running_model', str, 'ERROR')
-
-    # https://platform.openai.com/docs/api-reference/making-requests?lang=python
-    return {
-        "id": str(uuid4()),
-        "object": "chat.completion",
-        "created": int(time.time()),
-        "model": running_model,
-        "usage": {
-            "prompt_tokens": prompt_tokens,
-            "completion_tokens": completion_tokens,
-            "total_tokens": prompt_tokens + completion_tokens
-        },
-        "choices": [
-            {
-                "message": {
-                    "role": "assistant",
-                    "content": text
-                },
-                "finish_reason": finish_reason,
-                "index": 0
-            }
-        ]
-    }
-
-
 def transform_prompt_to_text(prompt: list):
     text = ''
     for item in prompt:
@@ -82,26 +23,26 @@ def transform_prompt_to_text(prompt: list):
     return text.strip('\n')
 
 
-def handle_blocking_request(json_data: dict):
+def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
     try:
-        r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+        r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
     except requests.exceptions.ReadTimeout:
-        print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
+        # print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
        return False, None, 'Request to backend timed out'
     except Exception as e:
-        print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
+        # print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
         return False, None, 'Request to backend encountered error'
     if r.status_code != 200:
-        print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
+        # print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
         return False, r, f'Backend returned {r.status_code}'
     return True, r, None
 
 
-def generate(json_data: dict):
+def generate(json_data: dict, cluster_backend, timeout: int = None):
     if json_data.get('stream'):
         try:
-            return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
+            return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
         except Exception as e:
-            print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
+            return False
     else:
-        return handle_blocking_request(json_data)
+        return handle_blocking_request(json_data, cluster_backend, timeout=timeout)

View File

@ -1,3 +1,7 @@
import requests
from llm_server import opts
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p> vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong> <strong>Supported Parameters:</strong>
<ul> <ul>
@ -7,4 +11,4 @@ vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="
<li><kbd>max_new_tokens</kbd></li> <li><kbd>max_new_tokens</kbd></li>
<li><kbd>num_beams</kbd> <span style="font-size:9pt">(setting to greater than 1 enables beam search)</span></li> <li><kbd>num_beams</kbd> <span style="font-size:9pt">(setting to greater than 1 enables beam search)</span></li>
<li><kbd>ban_eos_token</kbd></li> <li><kbd>ban_eos_token</kbd></li>
</ul>""" </ul>"""

View File

@ -1,26 +1,51 @@
import concurrent.futures
import requests import requests
import tiktoken import tiktoken
from llm_server import opts from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.logging import create_logger
def tokenize(prompt: str) -> int: def tokenize(prompt: str, backend_url: str) -> int:
assert backend_url
assert isinstance(backend_url, str)
if not prompt: if not prompt:
# The tokenizers have issues when the prompt is None. # The tokenizers have issues when the prompt is None.
return 0 return 0
assert isinstance(prompt, str)
logger = create_logger('tokenizer')
# The backend could have died between when the request was
# submitted and now, so let's double check it's still online.
backend_url = cluster_config.validate_backend(backend_url)
tokenizer = tiktoken.get_encoding("cl100k_base") tokenizer = tiktoken.get_encoding("cl100k_base")
# First we tokenize it locally to determine if it's worth sending it to the backend. # Split the prompt into 2000 character chunks
initial_estimate = len(tokenizer.encode(prompt)) chunk_size = 2000
if initial_estimate <= opts.context_size + 200: chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
# Define a function to send a chunk to the server
def send_chunk(chunk):
try: try:
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout) r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json() j = r.json()
return j['length'] return j['length']
except Exception as e: except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}') logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
return len(tokenizer.encode(prompt)) + 10 return len(tokenizer.encode(chunk)) + 10
else:
# If the result was greater than our context size, return the estimate. # Use a ThreadPoolExecutor to send all chunks to the server at once
# We won't be sending it through the backend so it doesn't need to be accurate. with concurrent.futures.ThreadPoolExecutor() as executor:
return initial_estimate future_to_chunk = {executor.submit(send_chunk, chunk): chunk for chunk in chunks}
for future in concurrent.futures.as_completed(future_to_chunk):
chunk = future_to_chunk[future]
try:
data = future.result()
except Exception as exc:
logger.warning('%r generated an exception: %s' % (chunk, exc))
return sum(future.result() for future in future_to_chunk)
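The chunked tokenization above can be sketched without the server: split the prompt into fixed-size character chunks, count each chunk concurrently, and sum the results. The stand-in counter below just splits on whitespace; in the real code each chunk is POSTed to the backend's `/tokenize` endpoint.
```python
import concurrent.futures

CHUNK_SIZE = 2000  # characters per chunk, mirroring the value used above

def fake_remote_token_count(chunk: str) -> int:
    # Stand-in for the POST to the backend's /tokenize endpoint.
    return len(chunk.split())

def count_tokens_chunked(prompt: str) -> int:
    if not prompt:
        return 0
    chunks = [prompt[i:i + CHUNK_SIZE] for i in range(0, len(prompt), CHUNK_SIZE)]
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = [executor.submit(fake_remote_token_count, c) for c in chunks]
        return sum(f.result() for f in futures)

print(count_tokens_chunked('hello world ' * 3000))
```
Note that summing per-chunk counts can overcount slightly when a token straddles a chunk boundary, so the result should be treated as an estimate.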

View File

@ -1,10 +1,9 @@
import threading
from typing import Tuple from typing import Tuple
from flask import jsonify from flask import jsonify
from vllm import SamplingParams from vllm import SamplingParams
from llm_server.database.database import log_prompt from llm_server.database.log_to_db import log_to_db
from llm_server.llm.llm_backend import LLMBackend from llm_server.llm.llm_backend import LLMBackend
@ -19,16 +18,8 @@ class VLLMBackend(LLMBackend):
# Failsafe # Failsafe
backend_response = '' backend_response = ''
r_url = request.url log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
def background_task():
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
return jsonify({'results': [{'text': backend_response}]}), 200 return jsonify({'results': [{'text': backend_response}]}), 200
@ -38,14 +29,20 @@ class VLLMBackend(LLMBackend):
top_k = parameters.get('top_k', self._default_params['top_k']) top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0: if top_k <= 0:
top_k = -1 top_k = -1
# TODO: support more params
sampling_params = SamplingParams( sampling_params = SamplingParams(
temperature=parameters.get('temperature', self._default_params['temperature']), temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']), top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k, top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False, use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self._default_params['stop']), stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False), ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens']) max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
) )
except ValueError as e: except ValueError as e:
return None, str(e).strip('.') return None, str(e).strip('.')

52
llm_server/logging.py Normal file
View File

@ -0,0 +1,52 @@
import logging
import coloredlogs
from llm_server import opts
class LoggingInfo:
def __init__(self):
self._level = logging.INFO
self._format = opts.LOGGING_FORMAT
@property
def level(self):
return self._level
@level.setter
def level(self, value):
self._level = value
@property
def format(self):
return self._format
@format.setter
def format(self, value):
self._format = value
logging_info = LoggingInfo()
def init_logging():
"""
Set up the parent logger.
:return:
"""
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)
def create_logger(name):
logger = logging.getLogger('llm_server').getChild(name)
logger.setLevel(logging_info.level)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging_info.level)
formatter = logging.Formatter(logging_info.format)
handler.setFormatter(formatter)
logger.addHandler(handler)
coloredlogs.install(logger=logger, level=logging_info.level)
return logger
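A quick usage sketch for the helpers above, assuming the package is importable; the child logger name and message are arbitrary:
```python
from llm_server.logging import create_logger, init_logging

init_logging()                      # configure the parent 'llm_server' logger once
logger = create_logger('example')   # namespaced child: 'llm_server.example'
logger.info('hello')
# Prints something like:
# 2023-10-27 19:19:22,000: INFO:llm_server.example - hello
```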

1
llm_server/messages.py Normal file
View File

@ -0,0 +1 @@
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -1,52 +0,0 @@
import json
from datetime import datetime, timedelta
import requests
from llm_server import opts
def get_power_states():
gpu_num = 0
output = {}
while True:
url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
try:
response = requests.get(url, timeout=10)
if response.status_code != 200:
break
data = json.loads(response.text)
power_state_data = data['data'][0]
power_state = None
for i in range(1, len(power_state_data)):
if power_state_data[i] == 1:
power_state = data['labels'][i]
break
output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
except Exception as e:
print('Failed to fetch Netdata metrics:', e)
return output
gpu_num += 1
return output
def get_gpu_wh(gpu_id: int):
chart_name = f"nvidia_smi.gpu{gpu_id}_power"
now = datetime.now()
one_hour_ago = now - timedelta(hours=1)
num_seconds = int((now - one_hour_ago).total_seconds())
params = {
"chart": chart_name,
"after": int(one_hour_ago.timestamp()),
"before": int(now.timestamp()),
"points": num_seconds,
"group": "second",
"format": "json",
"options": "absolute|jsonwrap"
}
response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
data = json.loads(response.text)
total_power_usage_watts = sum(point[1] for point in data['result']['data'])
# total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
return total_power_usage_kwh

View File

@ -1,12 +1,11 @@
# Read-only global variables # Read-only global variables
# Uppercase variables are read-only globals.
# Lowercase variables are ones that are set on startup and are never changed.
# TODO: rewrite the config system so I don't have to add every single config default here # TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'ERROR' frontend_api_mode = 'ooba'
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None
context_size = 5555
max_new_tokens = 500 max_new_tokens = 500
auth_required = False auth_required = False
log_prompts = False log_prompts = False
@ -33,7 +32,15 @@ openai_expose_our_model = False
openai_force_no_hashes = True openai_force_no_hashes = True
include_system_tokens_in_stats = True include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5 openai_moderation_scan_last_n = 5
openai_moderation_workers = 10
openai_org_name = 'OpenAI' openai_org_name = 'OpenAI'
openai_silent_trim = False openai_silent_trim = False
openai_moderation_enabled = True openai_moderation_enabled = True
cluster = {}
show_backends = True
background_homepage_cacher = True
openai_moderation_timeout = 5
prioritize_by_size = False
cluster_workers = 0
redis_stream_timeout = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"

View File

@ -1,21 +1,9 @@
import sys import sys
from redis import Redis from llm_server.custom_redis import redis
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
def server_startup(s): def server_startup(s):
if not redis.get('daemon_started', bool): if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?') print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1) sys.exit(1)
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
# Cache the initial stats
print('Loading backend stats...')
generate_stats()

View File

@ -1,11 +1,18 @@
from llm_server import opts from llm_server.cluster.cluster_config import cluster_config
from llm_server.routes.cache import redis from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, level: str = 'info'): def format_sillytavern_err(msg: str, backend_url: str = None, error_type: str = 'info'):
http_host = redis.get('http_host', str) if backend_url:
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
else:
cluster_backend_hash = 'none'
http_host = redis.get('http_host', dtype=str)
return f"""``` return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} === === MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <- -> {error_type.upper()} <-
{msg} {msg}
```
```
BACKEND: {cluster_backend_hash}
```""" ```"""

View File

@ -100,4 +100,4 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
j = json.loads(str(data)) j = json.loads(str(data))
return True, j return True, j
except Exception as e: except Exception as e:
return False, e return False, e

View File

@ -0,0 +1,15 @@
def estimate_model_size(config: dict):
"""
Estimate a model's parameter count (in billions) from its config. No idea if this is correct,
but it allows us to compare models.
:param config:
:return:
"""
vocab_size = config.get('vocab_size')
hidden_size = config.get('hidden_size')
num_hidden_layers = config.get('num_hidden_layers')
intermediate_size = config.get('intermediate_size')
if vocab_size and hidden_size and num_hidden_layers and intermediate_size:
total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3)))
return int(total_params / 1e9)
return 0
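As a sanity check of the formula, plugging in config values in the ballpark of a 13B Llama-style model (the numbers below are illustrative, not pulled from this repo, and `estimate_model_size` above is assumed to be in scope) gives a figure in the right neighborhood:
```python
config = {
    'vocab_size': 32000,
    'hidden_size': 5120,
    'num_hidden_layers': 40,
    'intermediate_size': 13824,
}
# (32000 * 5120) + 40 * ((5120 * 13824 * 4) + (5120 * 5120 * 3))
#   = 163,840,000 + 40 * (283,115,520 + 78,643,200)
#   = 14,634,188,800 total parameters -> int(14.63) == 14
print(estimate_model_size(config))  # 14
```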

View File

@ -3,8 +3,8 @@ from typing import Tuple
import flask import flask
from flask import jsonify, request from flask import jsonify, request
from llm_server import opts from llm_server import messages, opts
from llm_server.database.database import log_prompt from llm_server.database.log_to_db import log_to_db
from llm_server.routes.helpers.client import format_sillytavern_err from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler from llm_server.routes.request_handler import RequestHandler
@ -13,8 +13,11 @@ class OobaRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def handle_request(self): def handle_request(self, return_ok: bool = True):
assert not self.used assert not self.used
if self.offline:
print('This backend is offline:', messages.BACKEND_OFFLINE)
return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request() request_valid, invalid_response = self.validate_request()
if not request_valid: if not request_valid:
@ -25,14 +28,19 @@ class OobaRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': prompt} llm_request = {**self.parameters, 'prompt': prompt}
_, backend_response = self.generate_response(llm_request) _, backend_response = self.generate_response(llm_request)
return backend_response if return_ok:
# Always return 200 so ST displays our error messages
return backend_response[0], 200
else:
# The OpenAI route needs to detect 429 errors.
return backend_response
def handle_ratelimited(self, do_log: bool = True): def handle_ratelimited(self, do_log: bool = True):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg) backend_response = self.handle_error(msg)
if do_log: if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True) log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code return backend_response[0], 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true' disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
@ -40,7 +48,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this # TODO: how to format this
response_msg = error_msg response_msg = error_msg
else: else:
response_msg = format_sillytavern_err(error_msg, error_type) response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
return jsonify({ return jsonify({
'results': [{'text': response_msg}] 'results': [{'text': response_msg}]

View File

@ -5,9 +5,11 @@ from ..server_error import handle_server_error
from ... import opts from ... import opts
openai_bp = Blueprint('openai/v1/', __name__) openai_bp = Blueprint('openai/v1/', __name__)
openai_model_bp = Blueprint('openai/', __name__)
@openai_bp.before_request @openai_bp.before_request
@openai_model_bp.before_request
def before_oai_request(): def before_oai_request():
if not opts.enable_openi_compatible_backend: if not opts.enable_openi_compatible_backend:
return 'The OpenAI-compatible backend is disabled.', 401 return 'The OpenAI-compatible backend is disabled.', 401
@ -15,8 +17,22 @@ def before_oai_request():
@openai_bp.errorhandler(500) @openai_bp.errorhandler(500)
@openai_model_bp.errorhandler(500)
def handle_error(e): def handle_error(e):
return handle_server_error(e) """
Found Codes:
"auth_subrequest_error"
"""
print('OAI returning error:', e)
return jsonify({
"error": {
"message": "Internal server error",
"type": "auth_subrequest_error",
"param": None,
"code": "internal_error"
}
}), 500
from .models import openai_list_models from .models import openai_list_models

View File

@ -1,113 +1,175 @@
import json import json
import threading
import time import time
import traceback import traceback
import ujson
from flask import Response, jsonify, request from flask import Response, jsonify, request
from redis import Redis
from . import openai_bp from llm_server.custom_redis import redis
from ..cache import redis from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ... import opts from ... import opts
from ...database.database import log_prompt from ...database.log_to_db import log_to_db
from ...llm.generator import generator from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from ...llm.vllm import tokenize
# TODO: add rate-limit headers? # TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST']) @openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions(): @openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
def openai_chat_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'): if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else: else:
handler = OpenAIRequestHandler(request, request_json_body) handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if request_json_body.get('stream'): if handler.offline:
if not opts.enable_streaming: return return_invalid_model_err(model_name)
# TODO: return a proper OAI error message
return 'disabled', 401
if opts.mode != 'vllm': if not request_json_body.get('stream'):
# TODO: implement other backends
raise NotImplementedError
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
try:
response = generator(msg_to_backend)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
else:
try: try:
return handler.handle_request() return handler.handle_request()
except Exception: except Exception:
traceback.print_exc() traceback.print_exc()
return 'Internal server error', 500 return 'Internal server error', 500
else:
if not opts.enable_streaming:
return 'Streaming disabled', 403
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
handler.parameters, e = handler.get_parameters()
handler.request_json_body = {
'messages': handler.request_json_body['messages'],
'model': handler.request_json_body['model'],
**handler.parameters
}
if opts.openai_silent_trim:
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
# Need to set the prompt in the JSON body since that's what the inference worker expects.
handler.request_json_body['prompt'] = handler.prompt
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
429,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
try:
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
# Need to do this before we enter generate() since we want to be able to
# return a 408 if necessary.
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None # set to None so that the finally block ignores it.
return 'Request Timeout', 408
def generate():
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
# Not printing error since we can just check the daemon log.
print('OAI streaming encountered error')
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": data['new']
},
"finish_reason": None
}
]
}
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except GeneratorExit:
return
except Exception:
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'Internal server error', 500

View File

@ -1,38 +1,68 @@
import time import time
import traceback import traceback
from flask import jsonify, make_response, request import simplejson as json
import ujson
from flask import Response, jsonify, request
from redis import Redis
from . import openai_bp from llm_server.custom_redis import redis
from ..cache import redis from . import openai_bp, openai_model_bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ... import opts from ... import opts
from ...database.log_to_db import log_to_db
from ...llm import get_token_count from ...llm import get_token_count
from ...llm.openai.transform import build_openai_response, generate_oai_string from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers? # TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST']) @openai_bp.route('/completions', methods=['POST'])
def openai_completions(): @openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'): if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else: else:
try: handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
response, status_code = OobaRequestHandler(request).handle_request() if handler.offline:
if status_code != 200: return return_invalid_model_err(model_name)
return status_code
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
handler.request_json_body['prompt'] = handler.prompt
if not request_json_body.get('stream'):
invalid_oai_err_msg = validate_oai(request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
response, status_code = handler.handle_request(return_ok=False)
if status_code == 429:
return handler.handle_ratelimited()
output = response.json['results'][0]['text'] output = response.json['results'][0]['text']
# TODO: async/await prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
prompt_tokens = get_token_count(request_json_body['prompt']) response_tokens = get_token_count(output, handler.backend_url)
response_tokens = get_token_count(output) running_model = redis.get('running_model', 'ERROR', dtype=str)
running_model = redis.get('running_model', str, 'ERROR')
response = make_response(jsonify({ response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}", "id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion", "object": "text_completion",
"created": int(time.time()), "created": int(time.time()),
@ -42,7 +72,7 @@ def openai_completions():
"text": output, "text": output,
"index": 0, "index": 0,
"logprobs": None, "logprobs": None,
"finish_reason": None "finish_reason": "stop"
} }
], ],
"usage": { "usage": {
@ -50,12 +80,141 @@ def openai_completions():
"completion_tokens": response_tokens, "completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens "total_tokens": prompt_tokens + response_tokens
} }
}), 200) })
stats = redis.get('proxy_stats', dict) # TODO:
if stats: # stats = redis.get('proxy_stats', dtype=dict)
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] # if stats:
return response # response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
except Exception: return response, 200
traceback.print_exc() else:
return 'Internal Server Error', 500 if not opts.enable_streaming:
return 'Streaming disabled', 403
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
handler.parameters, _ = handler.get_parameters()
handler.request_json_body = {
'prompt': handler.request_json_body['prompt'],
'model': handler.request_json_body['model'],
**handler.parameters
}
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
429,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
try:
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None
return 'Request Timeout', 408
def generate():
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
print('OAI streaming encountered error')
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": data['new']
},
"finish_reason": None
}
]
}
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except GeneratorExit:
# This should be triggered if a client disconnects early.
return
except Exception:
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'Internal server error', 500

View File

@ -1,7 +1,7 @@
from flask import Response from flask import Response
from . import openai_bp from . import openai_bp
from ..cache import flask_cache from llm_server.custom_redis import flask_cache
from ... import opts from ... import opts

View File

@ -3,59 +3,58 @@ import traceback
import requests import requests
from flask import jsonify from flask import jsonify
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time from ..stats import server_start_time
from ... import opts from ... import opts
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
from ...helpers import jsonify_pretty from ...helpers import jsonify_pretty
from ...llm.info import get_running_model from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET']) @openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True) @flask_cache.cached(timeout=60, query_string=True)
def openai_list_models(): def openai_list_models():
model, error = get_running_model() model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not model: if not model_name:
response = jsonify({ response = jsonify({
'code': 502, 'code': 502,
'msg': 'failed to reach backend', 'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us }), 500 # return 500 so Cloudflare doesn't intercept us
else: else:
running_model = redis.get('running_model', str, 'ERROR') running_model = redis.get('running_model', 'ERROR', dtype=str)
oai = fetch_openai_models() oai = fetch_openai_models()
r = [] r = {
"object": "list",
"data": oai
}
# TODO: verify this works
if opts.openai_expose_our_model: if opts.openai_expose_our_model:
r = [{ r["data"].insert(0, {
"object": "list", "id": running_model,
"data": [ "object": "model",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{ {
"id": running_model, "id": running_model,
"object": "model", "object": "model_permission",
"created": int(server_start_time.timestamp()), "created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name, "allow_create_engine": False,
"permission": [ "allow_sampling": False,
{ "allow_logprobs": False,
"id": running_model, "allow_search_indices": False,
"object": "model_permission", "allow_view": True,
"created": int(server_start_time.timestamp()), "allow_fine_tuning": False,
"allow_create_engine": False, "organization": "*",
"allow_sampling": False, "group": None,
"allow_logprobs": False, "is_blocking": False
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
],
"root": None,
"parent": None
} }
] ],
}] "root": None,
response = jsonify_pretty(r + oai), 200 "parent": None
})
response = jsonify_pretty(r), 200
return response return response
@ -64,7 +63,14 @@ def fetch_openai_models():
if opts.openai_api_key: if opts.openai_api_key:
try: try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10) response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
return response.json()['data'] j = response.json()['data']
# The "modelperm" string appears to be user-specific, so we'll
# randomize it just to be safe.
for model in range(len(j)):
for p in range(len(j[model]['permission'])):
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
return j
except: except:
traceback.print_exc() traceback.print_exc()
return [] return []

View File

@ -1,7 +1,7 @@
from flask import jsonify from flask import jsonify
from . import openai_bp from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time from ..stats import server_start_time
@ -17,7 +17,7 @@ def openai_organizations():
"id": f"org-{generate_oai_string(24)}", "id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()), "created": int(server_start_time.timestamp()),
"title": "Personal", "title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx", "name": f"user-{generate_oai_string(24)}",
"description": "Personal org for bobjoe@0.0.0.0", "description": "Personal org for bobjoe@0.0.0.0",
"personal": True, "personal": True,
"is_default": True, "is_default": True,

View File

@ -1,14 +1,21 @@
import json import json
import re
import time
import traceback import traceback
from typing import Tuple from typing import Tuple
from uuid import uuid4 from uuid import uuid4
import flask import flask
from flask import jsonify from flask import Response, jsonify, make_response
from llm_server import opts from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results from llm_server.workers.moderator import add_moderation_task, get_results
@ -20,20 +27,37 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]: def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print('OAI Offline:', msg)
return self.handle_error(msg)
if opts.openai_silent_trim: if opts.openai_silent_trim:
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size) oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else: else:
oai_messages = self.request.json['messages'] oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages) self.prompt = transform_messages_to_prompt(oai_messages)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
request_valid, invalid_response = self.validate_request() request_valid, invalid_response = self.validate_request()
if not request_valid: if not request_valid:
return invalid_response return invalid_response
if opts.openai_api_key and is_api_key_moderated(self.token): if not self.prompt:
# TODO: format this as an openai error message
return Response('Invalid prompt'), 400
# TODO: support Ooba backend
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
invalid_oai_err_msg = validate_oai(self.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
try: try:
# Gather the last message from the user and all preceding system messages # Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy() msg_l = self.request.json['messages'].copy()
msg_l.reverse() msg_l.reverse()
tag = uuid4() tag = uuid4()
@ -49,33 +73,40 @@ class OpenAIRequestHandler(RequestHandler):
self.prompt = transform_messages_to_prompt(self.request.json['messages']) self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e: except Exception as e:
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}') print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc()) traceback.print_exc()
# Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
if opts.openai_force_no_hashes:
self.parameters['stop'].append('### ')
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
self.request_json_body['top_p'] = 0.01
llm_request = {**self.parameters, 'prompt': self.prompt} llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request) (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
model = self.request_json_body.get('model') model = self.request_json_body.get('model')
if success: if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else: else:
return backend_response, backend_response_status_code return backend_response, backend_response_status_code
def handle_ratelimited(self, do_log: bool = True): def handle_ratelimited(self, do_log: bool = True):
# TODO: return a simulated OpenAI error message model_choices, default_model = get_model_choices()
# Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another. default_model_info = model_choices[default_model]
return 'Ratelimited', 429 w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2
response = jsonify({
"error": {
"message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
"type": "rate_limit_exceeded",
"param": None,
"code": None
}
})
response.headers['x-ratelimit-limit-requests'] = '2'
response.headers['x-ratelimit-remaining-requests'] = '0'
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
if do_log:
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
# TODO: return a simulated OpenAI error message print('OAI Error:', error_msg)
return jsonify({ return jsonify({
"error": { "error": {
"message": "Invalid request, check your parameters and try again.", "message": "Invalid request, check your parameters and try again.",
@ -84,3 +115,51 @@ class OpenAIRequestHandler(RequestHandler):
"code": None "code": None
} }
}), 400 }), 400
def build_openai_response(self, prompt, response, model=None):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
prompt_tokens = get_token_count(prompt, self.backend_url)
response_tokens = get_token_count(response, self.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
return response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
self.parameters, parameters_invalid_msg = self.get_parameters()
if not self.parameters:
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
return False, (Response('Invalid request, check your parameters and try again.'), 400)
invalid_oai_err_msg = validate_oai(self.parameters)
if invalid_oai_err_msg:
return False, invalid_oai_err_msg
# self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# If the parameters were invalid, let the superclass deal with it.
return super().validate_request(prompt, do_log)

View File

@ -1,12 +1,15 @@
import json
import pickle import pickle
import time import time
from typing import Tuple
from uuid import uuid4 from uuid import uuid4
import ujson as json
from redis import Redis from redis import Redis
from llm_server import opts from llm_server import opts
from llm_server.routes.cache import redis from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
def increment_ip_count(client_ip: str, redis_key): def increment_ip_count(client_ip: str, redis_key):
@ -20,24 +23,30 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue: class RedisPriorityQueue:
def __init__(self): """
self.redis = Redis(host='localhost', port=6379, db=15) A queue for a specific backend.
self.pubsub = self.redis.pubsub() """
self.pubsub.subscribe('events')
def put(self, item, priority): def __init__(self, name, db: int = 12):
event = DataEvent() self.name = name
self.redis = RedisCustom(name, db=db)
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
# TODO: remove this when we're sure nothing strange is happening
assert item is not None
assert priority is not None
assert selected_model is not None
# Check if the IP is already in the dictionary and if it has reached the limit # Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1]) ip_count = self.get_ip_request_count(item[1])
if ip_count: _, simultaneous_ip = get_token_ratelimit(item[2])
ip_count = int(ip_count) if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0: print(f'Rejecting request from {item[1]} - {ip_count} requests queued.')
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request return None # reject the request
self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority}) timestamp = time.time()
self.increment_ip_count(item[1], 'queued_ip_count') event = DataEvent()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
return event return event
def get(self): def get(self):
@ -45,31 +54,59 @@ class RedisPriorityQueue:
data = self.redis.zpopmin('queue') data = self.redis.zpopmin('queue')
if data: if data:
item = json.loads(data[0][0]) item = json.loads(data[0][0])
client_ip = item[0][1]
self.decrement_ip_count(client_ip, 'queued_ip_count')
return item return item
time.sleep(0.1) # wait for something to be added to the queue time.sleep(0.1) # wait for something to be added to the queue
def increment_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(redis_key, client_ip, 1)
def decrement_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, -1)
if new_count <= 0:
self.redis.hdel(redis_key, client_ip)
def __len__(self): def __len__(self):
return self.redis.zcard('queue') return self.redis.zcard('queue')
def get_queued_ip_count(self, client_ip: str): def get_ip_request_count(self, client_ip: str):
q = self.redis.hget('queued_ip_count', client_ip) """
if not q: Get the number of requests in the queue from a specific IP.
return 0 This is a bit inefficient since we iterate over the entire queue, but
return 0 keeps the queue as a single source of truth instead of tracking a separate hashed
set which can get confusing.
If we run into slowdowns in the future, we should go back to the hashed set approach.
:param client_ip:
:return:
"""
start_time = time.time()
items = self.redis.zrange('queue', 0, -1)
count = 0
for item in items:
item_data = json.loads(item)
if item_data[0][1] == client_ip:
count += 1
elapsed_time = time.time() - start_time
if elapsed_time > 0.5:
raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
return count
def flush(self):
self.redis.flush()
def items(self):
return self.redis.zrange('queue', 0, -1)
def cleanup(self):
now = time.time()
for item in self.items():
item_data = json.loads(item)
timestamp = item_data[-2]
if now - timestamp > opts.backend_generate_request_timeout:
self.redis.zrem('queue', 0, item)
event_id = item_data[1]
event = DataEvent(event_id)
event.set((False, None, 'closed'))
print('Removed timed-out item from queue:', event_id)
class DataEvent: class DataEvent:
def __init__(self, event_id=None): """
Class to simplify pub/sub communication between producers and consumers.
"""
def __init__(self, event_id: str = None):
self.event_id = event_id if event_id else str(uuid4()) self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14) self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub() self.pubsub = self.redis.pubsub()
@ -84,15 +121,89 @@ class DataEvent:
return pickle.loads(item['data']) return pickle.loads(item['data'])
priority_queue = RedisPriorityQueue() def update_active_workers(key: str, operation: str):
if operation == 'incr':
redis.incr(f'active_gen_workers:{key}')
elif operation == 'decr':
redis.decr(f'active_gen_workers:{key}')
if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
redis.set(f'active_gen_workers:{key}', 0)
def incr_active_workers(): def incr_active_workers(selected_model: str, backend_url: str):
redis.incr('active_gen_workers') update_active_workers(selected_model, 'incr')
update_active_workers(backend_url, 'incr')
def decr_active_workers(): def decr_active_workers(selected_model: str, backend_url: str):
redis.decr('active_gen_workers') update_active_workers(selected_model, 'decr')
new_count = redis.get('active_gen_workers', int, 0) update_active_workers(backend_url, 'decr')
if new_count < 0:
redis.set('active_gen_workers', 0)
class PriorityQueue:
"""
Helper class to wrangle all the different queues.
"""
def __init__(self, backends: set = None):
"""
Only have to load the backends once.
:param backends:
"""
self.redis = Redis(host='localhost', port=6379, db=9)
if backends:
for item in backends:
self.redis.sadd('backends', item)
def get_backends(self):
return {x.decode('utf-8') for x in self.redis.smembers('backends')}
def get_queued_ip_count(self, client_ip: str):
count = 0
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
count += queue.get_ip_request_count(client_ip)
return count
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model, do_stream)
def activity(self):
lines = []
status_redis = RedisCustom('worker_status')
for worker in status_redis.keys():
lines.append((worker, status_redis.getp(worker)))
return sorted(lines)
def len(self, model_name):
count = 0
backends_with_models = set()
for k in self.get_backends():
info = cluster_config.get_backend(k)
if info.get('model') == model_name:
backends_with_models.add(k)
for backend_url in backends_with_models:
count += len(RedisPriorityQueue(backend_url))
return count
def __len__(self):
count = 0
p = set()
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
p.add((backend_url, len(queue)))
count += len(queue)
return count
def flush(self):
for k in self.redis.keys():
q = json.loads(self.redis.get(k))
q.flush()
self.redis.set(k, json.dumps(q))
def flush_db(self):
self.redis.flushdb()
priority_queue = PriorityQueue()
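To tie the pieces together, here is a rough producer-side sketch of the queue API above. It assumes a local Redis is running and that a worker process is consuming the backend's queue; the backend URL, client IP, and priority are placeholders, and the unpacking of `event.wait()` follows the three-element tuple this module publishes (e.g. `event.set((False, None, 'closed'))`).
```python
from llm_server.routes.queue import priority_queue

# Enqueue a request for a specific backend. put() returns None if the client
# already has too many queued requests, otherwise a DataEvent to wait on.
event = priority_queue.put(
    'http://10.0.0.5:7000',                      # backend_url (placeholder)
    (
        {'prompt': 'hi', 'max_new_tokens': 16},  # request JSON body
        '203.0.113.7',                           # client IP
        None,                                    # auth token (optional)
        {'max_new_tokens': 16},                  # parameters
    ),
    priority=9999,
    selected_model='some-model',
)

if event is None:
    print('Rejected: too many requests queued for this IP.')
else:
    success, response, error_msg = event.wait()  # blocks until a worker publishes a result
```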

View File

@ -5,23 +5,22 @@ import flask
from flask import Response, request from flask import Response, request
from llm_server import opts from llm_server import opts
from llm_server.database.conn import database from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.database.database import log_prompt from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit
from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.routes.helpers.http import require_api_key, validate_json from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue from llm_server.routes.queue import priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler: class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None): def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true' # self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# Routes need to validate it, here we just load it # Routes need to validate it, here we just load it
if incoming_json: if incoming_json:
@ -34,11 +33,38 @@ class RequestHandler:
self.start_time = time.time() self.start_time = time.time()
self.client_ip = self.get_client_ip() self.client_ip = self.get_client_ip()
self.token = self.get_auth_token() self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit() self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.backend = get_backend()
self.parameters = None self.parameters = None
self.used = False self.used = False
redis.zadd('recent_prompters', {self.client_ip: time.time()})
# This is null by default since most handlers need to transform the prompt in a specific way.
self.prompt = None
self.selected_model = selected_model
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
# Debug stuff
# if not self.cluster_backend_info.get('mode'):
# print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model'):
# print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model_config'):
# print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
self.offline = True
else:
self.offline = False
self.selected_model = self.cluster_backend_info['model']
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
if self.token and not self.token.startswith('SYSTEM__'):
# "recent_prompters" is only used for stats.
redis.zadd('recent_prompters', {self.client_ip: time.time()})
def check_online(self) -> bool:
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
return self.cluster_backend_info['online']
def get_auth_token(self): def get_auth_token(self):
if self.request_json_body.get('X-API-KEY'): if self.request_json_body.get('X-API-KEY'):
@ -49,6 +75,8 @@ class RequestHandler:
return parse_token(self.request.headers['Authorization']) return parse_token(self.request.headers['Authorization'])
def get_client_ip(self): def get_client_ip(self):
if self.request.headers.get('Llm-Connecting-Ip'):
return self.request.headers['Llm-Connecting-Ip']
if self.request.headers.get('X-Connecting-IP'): if self.request.headers.get('X-Connecting-IP'):
return self.request.headers.get('X-Connecting-IP') return self.request.headers.get('X-Connecting-IP')
elif self.request.headers.get('Cf-Connecting-Ip'): elif self.request.headers.get('Cf-Connecting-Ip'):
@ -58,26 +86,7 @@ class RequestHandler:
else: else:
return self.request.remote_addr return self.request.remote_addr
def get_token_ratelimit(self):
priority = DEFAULT_PRIORITY
simultaneous_ip = opts.simultaneous_requests_per_ip
if self.token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip
def get_parameters(self): def get_parameters(self):
if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body) parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg return parameters, parameters_invalid_msg
@ -119,7 +128,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error') backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log: if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True) log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
return False, backend_response return False, backend_response
return True, (None, 0) return True, (None, 0)
@ -131,14 +140,18 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True) request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid: if not request_valid:
return (False, None, None, 0), invalid_response return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority) event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
else: else:
event = None event = None
if not event: if not event:
return (False, None, None, 0), self.handle_ratelimited() return (False, None, None, 0), self.handle_ratelimited()
# TODO: add wait timeout
success, response, error_msg = event.wait() success, response, error_msg = event.wait()
if error_msg == 'closed':
return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
end_time = time.time() end_time = time.time()
elapsed_time = end_time - self.start_time elapsed_time = end_time - self.start_time
@ -160,7 +173,17 @@ class RequestHandler:
else: else:
error_msg = error_msg.strip('.') + '.' error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg) backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) log_to_db(ip=self.client_ip,
token=self.token,
prompt=prompt,
response=backend_response[0].data.decode('utf-8'),
gen_time=None,
parameters=self.parameters,
headers=dict(self.request.headers),
backend_response_code=response_status_code,
request_url=self.request.url,
backend_url=self.backend_url,
is_error=True)
return (False, None, None, 0), backend_response return (False, None, None, 0), backend_response
# =============================================== # ===============================================
@ -180,7 +203,7 @@ class RequestHandler:
if return_json_err: if return_json_err:
error_msg = 'The backend did not return valid JSON.' error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg) backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True) log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response return (False, None, None, 0), backend_response
# =============================================== # ===============================================
@ -189,22 +212,29 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool: def is_client_ratelimited(self) -> bool:
if self.token_priority == 0:
return False
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip)) queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
x = redis.hget('processing_ips', self.client_ip) x = redis.hget('processing_ips', self.client_ip)
if x: if x:
processing_ip = int(x) processing_ip = int(x)
else: else:
processing_ip = 0 processing_ip = 0
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
return False if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
else: print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
return True return True
else:
return False
def handle_request(self) -> Tuple[flask.Response, int]: def handle_request(self) -> Tuple[flask.Response, int]:
# Must include this in your child. # Must include this in your child.
# if self.used: # assert not self.used
# raise Exception('Can only use a RequestHandler object once.') # if self.offline:
# msg = f'{self.selected_model} is not a valid model choice.'
# print(msg)
# return format_sillytavern_err(msg)
raise NotImplementedError raise NotImplementedError
def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]: def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
@ -214,11 +244,11 @@ class RequestHandler:
raise NotImplementedError raise NotImplementedError
def get_backend_handler(mode, backend_url: str):
    if mode == 'oobabooga':
        return OobaboogaBackend(backend_url)
    elif mode == 'vllm':
        return VLLMBackend(backend_url)
    else:
        raise Exception

View File

@ -1,3 +1,3 @@
def handle_server_error(e):
    print('Internal Error:', e)
    return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@ -1,33 +1,11 @@
from datetime import datetime from datetime import datetime
from llm_server.routes.cache import redis from llm_server.custom_redis import redis
from llm_server.helpers import round_up_base
# proompters_5_min = 0
# concurrent_semaphore = Semaphore(concurrent_gens)
server_start_time = datetime.now() server_start_time = datetime.now()
# TODO: do I need this?
# def elapsed_times_cleanup():
# global wait_in_queue_elapsed
# while True:
# current_time = time.time()
# with wait_in_queue_elapsed_lock:
# global wait_in_queue_elapsed
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
# time.sleep(1)
def calculate_avg_gen_time():
# Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time')
if average_generation_time is None:
return 0
else:
return float(average_generation_time)
def get_total_proompts(): def get_total_proompts():
count = redis.get('proompts') count = redis.get('proompts')
if count is None: if count is None:
@ -37,10 +15,27 @@ def get_total_proompts():
return count return count
def get_active_gen_workers_model(selected_model: str = None):
    return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0)
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
# No queue
return gen_time_calc
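A worked example of the estimate above, assuming `round_up_base` rounds up to the nearest multiple of `base`:

```python
# 4 concurrent workers, all busy, 8 proompters queued, gen_time_calc = 10 seconds.
queue_wait = (8 / 4) * 10   # 20 s to drain the queue in two passes of 4 (already a multiple of 10)
estimate = queue_wait + 10  # plus one cycle for the generations currently running -> 30 s
```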

View File

@ -3,18 +3,18 @@ import traceback
from flask import jsonify, request from flask import jsonify, request
from . import bp from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
@bp.route('/generate', methods=['POST']) @bp.route('/v1/generate', methods=['POST'])
def generate(): @bp.route('/<model_name>/v1/generate', methods=['POST'])
def generate(model_name=None):
request_valid_json, request_json_body = validate_json(request) request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'): if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400 return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else: else:
handler = OobaRequestHandler(request) handler = OobaRequestHandler(request, selected_model=model_name)
try: try:
return handler.handle_request() return handler.handle_request()
except Exception: except Exception:

View File

@ -2,83 +2,30 @@ import time
from datetime import datetime from datetime import datetime
from llm_server import opts from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base from llm_server.helpers import deep_sort
from llm_server.llm.info import get_running_model from llm_server.routes.stats import get_total_proompts, server_start_time
from llm_server.netdata import get_power_states
from llm_server.routes.cache import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
# No queue
return gen_time_calc
# TODO: have routes/__init__.py point to the latest API version generate_stats()
def generate_stats(regen: bool = False): def generate_stats(regen: bool = False):
if not regen: if not regen:
c = redis.get('proxy_stats', dict) c = redis.getp('proxy_stats')
if c: if c:
return c return c
model_name, error = get_running_model() # will return False when the fetch fails model_choices, default_model = get_model_choices(regen=True)
if isinstance(model_name, bool):
online = False
else:
online = True
redis.set('running_model', model_name)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change base_client_api = redis.get('base_client_api', dtype=str)
# if len(t) == 0:
# estimated_wait = 0
# else:
# waits = [elapsed for end, elapsed in t]
# estimated_wait = int(sum(waits) / len(waits))
active_gen_workers = get_active_gen_workers()
proompters_in_queue = len(priority_queue)
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
if opts.netdata_root:
netdata_stats = {}
power_states = get_power_states()
for gpu, power_state in power_states.items():
netdata_stats[gpu] = {
'power_state': power_state,
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
}
else:
netdata_stats = {}
base_client_api = redis.get('base_client_api', str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf')) proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
output = { output = {
'models': {
'choices': model_choices,
'default': default_model,
},
'stats': { 'stats': {
'proompters': { 'proompters': {
'5_min': proompters_5_min, '5_min': proompters_5_min,
@ -86,39 +33,49 @@ def generate_stats(regen: bool = False):
}, },
'proompts_total': get_total_proompts() if opts.show_num_prompts else None, 'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(average_generation_time),
# 'estimated_avg_tps': estimated_avg_tps, # 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
}, },
'online': online,
'endpoints': { 'endpoints': {
'blocking': f'https://{base_client_api}', 'blocking': f'https://{base_client_api}',
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, 'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
}, },
'queue': {
'processing': active_gen_workers,
'queued': proompters_in_queue,
'estimated_wait_sec': int(estimated_wait_sec),
},
'timestamp': int(time.time()), 'timestamp': int(time.time()),
'config': { 'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token', 'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'concurrent': opts.concurrent_gens,
'model': opts.manual_model_name if opts.manual_model_name else model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip, 'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
'api_mode': opts.frontend_api_mode
}, },
'keys': { 'keys': {
'openaiKeys': '', 'openaiKeys': '',
'anthropicKeys': '', 'anthropicKeys': '',
}, },
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None, 'backends': {},
'nvidia': netdata_stats 'online': len(model_choices) > 0
} }
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
if opts.show_backends:
for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']:
continue
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
output['backends'][backend_info['hash']] = {
'uptime': backend_uptime,
'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),
'model': backend_info['model'],
'mode': backend_info['mode'],
'nvidia': backend_info['nvidia'],
'priority': backend_info['priority'],
}
result = deep_sort(output) result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then. # It may take a bit to get the base client API, so don't cache until then.
if base_client_api: if base_client_api:
redis.set_dict('proxy_stats', result) # Cache with no expiry redis.setp('proxy_stats', result)
return result return result

View File

@ -1,186 +1,200 @@
import json import json
import threading
import time import time
import traceback import traceback
from typing import Union
import ujson
from flask import request from flask import request
from redis import Redis
from ..cache import redis from . import bp
from ..helpers.http import require_api_key, validate_json from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue from ..queue import priority_queue
from ... import opts from ... import opts
from ...database.database import log_prompt from ...custom_redis import redis
from ...llm.generator import generator from ...database.log_to_db import log_to_db
from ...llm.vllm import tokenize from ...sock import sock
from ...stream import sock
# TODO: have workers process streaming requests # Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
# TODO: make sure to log the token as well (seems to be missing in the DB right now) # We solve this by splitting the routes
@sock.route('/api/v1/stream') @bp.route('/v1/stream')
def stream(ws): @bp.route('/<model_name>/v1/stream')
def send_err_and_quit(quitting_err_msg): def stream(model_name=None):
ws.send(json.dumps({ return 'This is a websocket endpoint.', 400
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_in_bg(quitting_err_msg, is_error=True)
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
def background_task_exception(): @sock.route('/v1/stream', bp=bp)
generated_tokens = tokenize(generated_text_bg) def stream_without_model(ws):
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error) do_stream(ws, model_name=None)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
if not opts.enable_streaming: @sock.route('/<model_name>/v1/stream', bp=bp)
return 'Streaming is disabled', 401 def stream_with_model(ws, model_name=None):
do_stream(ws, model_name)
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
auth_failure = require_api_key(request_json_body) def do_stream(ws, model_name):
if auth_failure: event_id = None
return auth_failure try:
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None,
is_error=True
)
handler = OobaRequestHandler(request, request_json_body) if not opts.enable_streaming:
generated_text = '' return 'Streaming disabled', 403
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
err_msg = None r_headers = dict(request.headers)
if handler.is_client_ratelimited(): r_url = request.url
r, _ = handler.handle_ratelimited(do_log=False) message_num = 0
err_msg = r.json['results'][0]['text']
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else: else:
request_valid, invalid_response = handler.validate_request(prompt=input_prompt) # We have to do auth ourselves since the details are sent in the message.
if not request_valid: auth_failure = require_api_key(request_json_body)
err_msg = invalid_response[0].json['results'][0]['text'] if auth_failure:
if err_msg: return auth_failure
send_err_and_quit(err_msg)
return
llm_request = { handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
**handler.parameters, if handler.offline:
'prompt': input_prompt, msg = f'{handler.selected_model} is not a valid model choice.'
'stream': True, print(msg)
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
ws.send(json.dumps({ ws.send(json.dumps({
'event': 'text_stream', 'event': 'text_stream',
'message_num': message_num, 'message_num': 0,
'text': msg 'text': msg
})) }))
return
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
else: else:
# Be extra careful when getting attributes from the response object request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
try: if not request_valid:
response_status_code = response.status_code err_msg = invalid_response[0].json['results'][0]['text']
except: if err_msg:
response_status_code = 0 send_err_and_quit(err_msg)
return
partial_response = b'' handler.parameters, _ = handler.get_parameters()
handler.prompt = input_prompt
handler.request_json_body = {
'prompt': handler.prompt,
**handler.parameters
}
for chunk in response.iter_content(chunk_size=1): event = None
partial_response += chunk if not handler.is_client_ratelimited():
if partial_response.endswith(b'\x00'): event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
json_strs = partial_response.split(b'\x00') if not event:
for json_str in json_strs: r = handler.handle_ratelimited()
if json_str: send_err_and_quit(r[0].data)
try: return
json_obj = json.loads(json_str.decode()) event_id = event.event_id
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
# The has client closed the stream.
if request:
request.close()
ws.close()
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
return
_, stream_name, error_msg = event.wait()
if error_msg:
print('Stream failed to start streaming:', error_msg)
ws.close(reason=1014, message='Request Timeout')
return
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0' # The ID of the last entry we read.
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
return
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
data = ujson.loads(item[b'data'])
if data['error']:
print(data['error'])
send_err_and_quit('Encountered exception while streaming.')
return
elif data['new']:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': data['new']
}))
message_num += 1 message_num += 1
partial_response = b'' # Reset the partial response generated_text = generated_text + data['new']
elif data['completed']:
# If there is no more data, break the loop return
if not chunk: except:
break send_err_and_quit('Encountered exception while streaming.')
traceback.print_exc()
end_time = time.time() finally:
elapsed_time = end_time - start_time try:
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code) ws.send(json.dumps({
except: 'event': 'stream_end',
traceback.print_exc() 'message_num': message_num
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] }))
ws.send(json.dumps({ except:
'event': 'text_stream', # The client closed the stream.
'message_num': message_num, pass
'text': generated_text if stream_name:
})) stream_redis.delete(stream_name)
if request: end_time = time.time()
request.close() elapsed_time = end_time - start_time
ws.close() log_to_db(ip=handler.client_ip,
log_in_bg(generated_text, is_error=True, status_code=response_status_code) token=handler.token,
return prompt=input_prompt,
finally: response=generated_text,
# The worker incremented it, we'll decrement it. gen_time=elapsed_time,
decrement_ip_count(handler.client_ip, 'processing_ips') parameters=handler.parameters,
decr_active_workers() headers=r_headers,
try: backend_response_code=response_status_code,
ws.send(json.dumps({ request_url=r_url,
'event': 'stream_end', backend_url=handler.backend_url
'message_num': message_num )
})) finally:
except: if event_id:
# The client closed the stream. redis.publish(f'notifications:{event_id}', 'canceled')
end_time = time.time() try:
elapsed_time = end_time - start_time # Must close the connection or greenlets will complain.
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text)) ws.close()
ws.close() # this is important if we encountered and error and exited early. except:
pass
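For reference, a minimal client sketch for this streaming endpoint (the public URL prefix and the `websocket-client` dependency are assumptions; the event format matches the `text_stream`/`stream_end` messages sent above):

```python
import json

from websocket import create_connection  # pip install websocket-client

ws = create_connection('wss://your-site/api/v1/stream')  # hypothetical public URL
ws.send(json.dumps({'prompt': 'Hello!', 'max_new_tokens': 32, 'X-API-KEY': 'your-token'}))
while True:
    msg = json.loads(ws.recv())
    if msg['event'] == 'text_stream':
        print(msg['text'], end='', flush=True)
    elif msg['event'] == 'stream_end':
        break
ws.close()
```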

View File

@ -2,22 +2,16 @@ import time
from flask import jsonify, request from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp from . import bp
from ..auth import requires_auth
from ..cache import flask_cache
from ... import opts from ... import opts
from ...llm.info import get_running_model from ...cluster.backend import get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
# @bp.route('/info', methods=['GET']) @bp.route('/v1/model', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True) @bp.route('/<model_name>/v1/model', methods=['GET'])
# def get_info(): def get_model(model_name=None):
# # requests.get()
# return 'yes'
@bp.route('/model', methods=['GET'])
def get_model():
# We will manage caching ourself since we don't want to cache # We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors. # when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url cache_key = 'model_cache::' + request.url
@ -26,24 +20,21 @@ def get_model():
if cached_response: if cached_response:
return cached_response return cached_response
model_name, error = get_running_model()
if not model_name: if not model_name:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not is_valid_model(model_name):
response = jsonify({ response = jsonify({
'code': 502, 'code': 400,
'msg': 'failed to reach backend', 'msg': 'Model does not exist.',
'type': error.__class__.__name__ }), 400
}), 500 # return 500 so Cloudflare doesn't intercept us
else: else:
num_backends = len(get_backends_from_model(model_name))
response = jsonify({ response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name, 'result': opts.manual_model_name if opts.manual_model_name else model_name,
'model_backend_count': num_backends,
'timestamp': int(time.time()) 'timestamp': int(time.time())
}), 200 }), 200
flask_cache.set(cache_key, response, timeout=60) flask_cache.set(cache_key, response, timeout=60)
return response return response
@bp.route('/backend', methods=['GET'])
@requires_auth
def get_backend():
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
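An illustrative response from the new `GET /v1/model` route (values made up; the shape follows the `jsonify` call above):

```json
{
  "model_backend_count": 2,
  "result": "my-model",
  "timestamp": 1698450000
}
```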

View File

@ -1,8 +1,10 @@
from flask import jsonify from flask import jsonify
from llm_server.custom_redis import flask_cache
from . import bp from . import bp
from .generate_stats import generate_stats from .generate_stats import generate_stats
from ..cache import flask_cache from ..auth import requires_auth
from ...cluster.cluster_config import cluster_config, get_backends
from ...helpers import jsonify_pretty from ...helpers import jsonify_pretty
@ -10,3 +12,14 @@ from ...helpers import jsonify_pretty
@flask_cache.cached(timeout=5, query_string=True) @flask_cache.cached(timeout=5, query_string=True)
def get_stats(): def get_stats():
return jsonify_pretty(generate_stats()) return jsonify_pretty(generate_stats())
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
online, offline = get_backends()
result = {}
for i in online + offline:
info = cluster_config.get_backend(i)
result[info['hash']] = info
return jsonify(result), 200

View File

@ -3,6 +3,6 @@ from flask_sock import Sock
sock = Sock()

def init_wssocket(app):
    global sock
    sock.init_app(app)

View File

@ -1,35 +0,0 @@
from threading import Thread
from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')

View File

@ -1,51 +0,0 @@
import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
need_to_wait()
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers()
if not request_json_body:
# This was a dummy request from the websocket handler.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
continue
try:
success, response, error_msg = generator(request_json_body)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers()
def start_workers(num_workers: int):
i = 0
for _ in range(num_workers):
t = threading.Thread(target=worker)
t.daemon = True
t.start()
i += 1
print(f'Started {i} inference workers.')
def need_to_wait():
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', int, 0)
s = time.time()
while active_workers >= opts.concurrent_gens:
time.sleep(0.01)
e = time.time()
if e - s > 0.5:
print(f'Worker was delayed {e - s} seconds.')

View File

@ -0,0 +1,32 @@
import time
from redis import Redis
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
# NOT NEEDED
def cleaner():
r = Redis(db=8)
stream_info = {}
while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
processed_streams = []
for stream in all_streams:
stream = stream.decode()
current_size = r.xlen(stream)
# If the stream is new or its size has changed, update the size and time in the dictionary
if stream not in stream_info or current_size != stream_info[stream]['size']:
stream_info[stream] = {'size': current_size, 'time': time.time()}
processed_streams.append(stream)
else:
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]
time.sleep(60)

View File

@ -0,0 +1,160 @@
import json
import threading
import time
import traceback
from uuid import uuid4
import ujson
from redis import Redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.logging import create_logger
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
stream_redis = Redis(db=8)
STREAM_NAME_PREFIX = 'stream'
def check_cancellation(event, event_id):
"""
This thread checks the pub/sub channel in the background so the main process
isn't bogged down with Redis calls. Otherwise, the main process slows down to 1 token/sec.
:param event:
:param event_id:
:return:
"""
pubsub = redis.pubsub()
pubsub.subscribe(f'notifications:{event_id}')
while not event.is_set():
message = pubsub.get_message()
if message and message['data'] == b'canceled':
event.set()
time.sleep(0.5) # check every half second
def get_stream_name(name: str):
return f'{STREAM_NAME_PREFIX}:{name}'
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
logger = create_logger('inferencer')
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
event = threading.Event()
threading.Thread(target=check_cancellation, args=(event, event_id)).start()
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
# If there is no more data, break the loop
if not chunk:
break
if event.is_set():
logger.debug('Client canceled generation')
response.close()
return
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
except AttributeError as e:
if str(e) == "'bool' object has no attribute 'iter_content'":
# We don't care about these errors.
logger.debug('failed to stream from backend - no response')
else:
raise
except Exception as e:
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
raise # We won't handle the exception here.
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
event.set() # stop the cancellation checking thread
#
def worker(backend_url):
logger = create_logger('inferencer')
status_redis = RedisCustom('worker_status')
worker_id = str(uuid4())
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
status_redis.setp(str(worker_id), 'waiting...')
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
event = DataEvent(event_id)
try:
backend_info = cluster_config.get_backend(backend_url)
except:
# This is not a critical error because it usually means that the backend is
# offline and this backend is in a state of transition from online to offline.
logger.debug(f'got an exception while getting info for backend {backend_url} - ', traceback.format_exc())
event.set((False, None, 'exception'))
continue
if not backend_info['online']:
event.set((False, None, 'canceled'))
continue
if not selected_model:
selected_model = backend_info['model']
logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}. Streaming: {do_stream}")
try:
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)
if do_stream:
status_redis.setp(str(worker_id), ('streaming', client_ip))
# Return the name of the stream that the slave should connect to.
event.set((True, get_stream_name(worker_id), None))
msg_to_backend = {
**parameters,
'prompt': request_json_body['prompt'],
'stream': True,
}
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
else:
# Normal inference (not streaming).
status_redis.setp(str(worker_id), ('generating', client_ip))
success, response, error_msg = generator(request_json_body, backend_url)
event.set((success, response, error_msg))
except:
logger.error(traceback.format_exc())
event.set((False, None, 'exception'))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None)
def start_workers(cluster: dict):
logger = create_logger('inferencer')
i = 0
for item in cluster:
for _ in range(item['concurrent_gens']):
t = threading.Thread(target=worker, args=(item['backend_url'],))
t.daemon = True
t.start()
i += 1
logger.info(f'Started {i} inference workers.')
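A sketch of the cancellation path described in `check_cancellation`'s docstring, as seen from the handler side (the event id is hypothetical; in practice it comes from the event returned by `priority_queue.put()`):

```python
from llm_server.custom_redis import redis

event_id = '1f2e3d4c'  # hypothetical; normally event.event_id
redis.publish(f'notifications:{event_id}', 'canceled')  # check_cancellation() sees this and stops the stream
```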

View File

@ -0,0 +1,31 @@
import pickle
import traceback
import redis
from llm_server.database.database import do_db_log
def db_logger():
"""
We don't want the logging operation to be blocking, so we will use a background worker
to do the logging.
:return:
"""
r = redis.Redis(host='localhost', port=6379, db=3)
p = r.pubsub()
p.subscribe('database-logger')
for message in p.listen():
try:
if message['type'] == 'message':
data = pickle.loads(message['data'])
function_name = data['function']
args = data['args']
kwargs = data['kwargs']
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
except:
traceback.print_exc()
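A sketch of what the producing side presumably publishes (this is an assumption about how `log_to_db` hands work to this worker; only the consumer is shown above):

```python
import pickle

import redis

r = redis.Redis(host='localhost', port=6379, db=3)
r.publish('database-logger', pickle.dumps({
    'function': 'log_prompt',  # routed to do_db_log() by the worker
    'args': (),
    'kwargs': {'ip': '203.0.113.7', 'prompt': 'hi', 'response': 'ok', 'is_error': False},  # hypothetical fields
}))
```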

View File

@ -1,56 +0,0 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
def main_background_thread():
redis.set('average_generation_elapsed_sec', 0)
redis.set('estimated_avg_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
while True:
# TODO: unify this
if opts.mode == 'oobabooga':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
elif opts.mode == 'vllm':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
else:
raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
redis.set('estimated_avg_tps', estimated_avg_tps)
time.sleep(60)

View File

@ -0,0 +1,57 @@
import time
import requests
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_backends
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
def main_background_thread():
while True:
online, offline = get_backends()
for backend_url in online:
backend_info = cluster_config.get_backend(backend_url)
backend_mode = backend_info['mode']
backend_info = get_info(backend_url, backend_mode)
running_model = backend_info.get('model')
if not running_model:
continue
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
if average_generation_elapsed_sec: # returns None on exception
cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
if average_output_tokens:
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
if average_generation_elapsed_sec and average_output_tokens:
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
if opts.background_homepage_cacher:
try:
base_client_api = redis.get('base_client_api', dtype=str)
r = requests.get('https://' + base_client_api, timeout=5)
except Exception as e:
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
backends = priority_queue.get_backends()
for backend_url in backends:
queue = RedisPriorityQueue(backend_url)
queue.cleanup()
time.sleep(30)
def calc_stats_for_backend(backend_url, running_model, backend_mode):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps
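For example, a weighted average of 240 response tokens over an average generation time of 12 seconds gives `round(240 / 12, 2)` = 20.0 estimated tokens per second.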

View File

@ -1,10 +1,13 @@
import json import json
import threading import threading
import time
import traceback import traceback
import redis as redis_redis import redis as redis_redis
from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger
redis_moderation = redis_redis.Redis() redis_moderation = redis_redis.Redis()
@ -16,36 +19,43 @@ def start_moderation_workers(num_workers):
t.daemon = True t.daemon = True
t.start() t.start()
i += 1 i += 1
print(f'Started {i} moderation workers.')
def moderation_worker(): # TODO: don't use UUID tags to identify items. Use native redis
while True:
result = redis_moderation.blpop('queue:msgs_to_check')
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
print(result)
traceback.print_exc()
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
def get_results(tag, num_tasks): def get_results(tag, num_tasks):
tag = str(tag) # Required for comparison with Redis results. tag = str(tag) # Cast a UUID4 to a string.
flagged_categories = set() flagged_categories = set()
num_results = 0 num_results = 0
start_time = time.time()
while num_results < num_tasks: while num_results < num_tasks:
result = redis_moderation.blpop('queue:flagged_categories') result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
if result is None:
break # Timeout occurred, break the loop.
result_tag, categories = json.loads(result[1]) result_tag, categories = json.loads(result[1])
if result_tag == tag: if result_tag == tag:
if categories: if categories:
for item in categories: for item in categories:
flagged_categories.add(item) flagged_categories.add(item)
num_results += 1 num_results += 1
if time.time() - start_time > opts.openai_moderation_timeout:
logger.warning('Timed out waiting for result from moderator')
break
return list(flagged_categories) return list(flagged_categories)
def moderation_worker():
logger = create_logger('moderator')
while True:
result = redis_moderation.blpop(['queue:msgs_to_check'])
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
logger.error(traceback.format_exc())
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
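A usage sketch of the moderation queue (the messages are hypothetical; the tag correlates results on `queue:flagged_categories` with this caller):

```python
from uuid import uuid4

from llm_server.workers.moderator import add_moderation_task, get_results

tag = uuid4()
messages = ['first user message', 'second user message']  # hypothetical
for msg in messages:
    add_moderation_task(msg, tag)
flagged = get_results(tag, num_tasks=len(messages))  # e.g. [] when nothing was flagged
```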

View File

@ -1,25 +1,34 @@
import logging
import time import time
import traceback
from llm_server.routes.cache import redis from llm_server.cluster.backend import get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.logging import create_logger
from llm_server.routes.queue import priority_queue from llm_server.routes.queue import priority_queue
logger = logging.getLogger('console_printer')
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
def console_printer(): def console_printer():
logger = create_logger('console_printer')
time.sleep(3) time.sleep(3)
while True: while True:
processing = redis.hkeys('processing_ips') try:
processing_count = 0 processing = redis.keys('active_gen_workers:http*') # backends always start with http
for ip in processing: processing_count = 0
processing_count += int(redis.hget('processing_ips', ip)) if len(processing):
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}') for k in processing:
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
activity = priority_queue.activity()
# Calculate the queue size the same way it's done on the stats.
queue_size = 0
running_models = get_running_models()
for model in running_models:
queue_size += priority_queue.len(model)
# Active Workers and Processing should read the same. If not, that's an issue.
logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
except:
logger.error(traceback.format_exc())
time.sleep(10) time.sleep(10)
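An illustrative status line from this printer (numbers made up): `Active Workers: 2 | Processing: 2 | Queued: 5 | Backends Online: 3`.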

View File

@ -1,6 +1,6 @@
import time

from llm_server.custom_redis import redis

def recent_prompters_thread():

View File

@ -0,0 +1,58 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.cluster.worker import cluster_worker
from llm_server.logging import create_logger
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.logger import db_logger
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers
from llm_server.workers.printer import console_printer
from llm_server.workers.recenter import recent_prompters_thread
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)
def start_background():
logger = create_logger('threader')
start_workers(opts.cluster)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
logger.info('Started the main background thread.')
num_moderators = opts.cluster_workers * 3
start_moderation_workers(num_moderators)
logger.info(f'Started {num_moderators} moderation workers.')
t = Thread(target=cache_stats)
t.daemon = True
t.start()
logger.info('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
logger.info('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
logger.info('Started the console printer.')
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
logger.info('Started the cluster worker.')
t = Thread(target=db_logger)
t.daemon = True
t.start()
logger.info('Started background logger.')

View File

@ -1,9 +0,0 @@
import time
from llm_server.routes.v1.generate_stats import generate_stats
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)

other/gradio/gradio_chat.py Normal file
View File

@ -0,0 +1,103 @@
import os
import sys
import time
import traceback
import warnings
from threading import Thread
import gradio as gr
import openai
import requests
warnings.filterwarnings("ignore")
API_BASE = os.getenv('API_BASE')
if not API_BASE:
print('Must set the secret variable API_BASE to your https://your-site/api')
sys.exit(1)
API_BASE = API_BASE.strip('/')
APP_TITLE = os.getenv('APP_TITLE')
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
TRACKING_CODE = os.getenv('TRACKING_CODE')
def background():
while True:
previous = openai.api_base
try:
r = requests.get(API_BASE + '/stats').json()
if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
else:
openai.api_base = API_BASE + '/openai/v1'
except:
traceback.print_exc()
openai.api_base = API_BASE + '/openai/v1'
if openai.api_base != previous:
print('Set primary model to', openai.api_base)
time.sleep(10)
if PRIMARY_MODEL_CHOICE:
t = Thread(target=background)
t.daemon = True
t.start()
print('Started the background thread.')
# A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
# the content in CONTEXT_TRIGGER_INJECTION will be injected.
# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples.
CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
openai.api_key = 'null'
openai.api_base = API_BASE + '/openai/v1'
def stream_response(prompt, history):
messages = []
do_injection = False
for human, assistant in history:
messages.append({'role': 'user', 'content': str(human)})
messages.append({'role': 'assistant', 'content': str(assistant)})
if CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in human:
do_injection = True
messages.append({'role': 'user', 'content': prompt})
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
try:
response = openai.ChatCompletion.create(
model='0',
messages=messages,
temperature=0,
max_tokens=300,
stream=True,
headers={'LLM-Source': 'huggingface-demo'}
)
except Exception:
raise gr.Error("Failed to reach inference endpoint.")
message = ''
for chunk in response:
if len(chunk['choices'][0]['delta']) != 0:
message += chunk['choices'][0]['delta']['content']
yield message
examples = ["hello"]
if CONTEXT_TRIGGER_PHRASE:
examples.insert(0, CONTEXT_TRIGGER_PHRASE)
with gr.Blocks(analytics_enabled=False) as demo:
gr.ChatInterface(stream_response, examples=examples, title=APP_TITLE, analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}')
if TRACKING_CODE:
print('Inserting tracking code')
gr.HTML(TRACKING_CODE)
demo.queue(concurrency_count=1, api_open=False).launch(show_api=False)
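A minimal console sketch of the same streaming call the Gradio handler makes (assumes API_BASE is exported and the 0.28-style openai client this script uses):
import os
import openai
openai.api_key = 'null'
openai.api_base = os.environ['API_BASE'].strip('/') + '/openai/v1'
response = openai.ChatCompletion.create(
    model='0',
    messages=[{'role': 'user', 'content': 'hello'}],
    max_tokens=32,
    stream=True,
)
for chunk in response:
    delta = chunk['choices'][0]['delta']
    if delta:
        print(delta.get('content', ''), end='', flush=True)
print()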

View File

@ -0,0 +1,3 @@
gradio
openai
requests

View File

@ -1,3 +1,8 @@
"""
This file is used to run certain tasks when the HTTP server starts.
It's located here so it doesn't get imported with daemon.py
"""
try: try:
import gevent.monkey import gevent.monkey

38
other/nginx-site.conf Normal file
View File

@ -0,0 +1,38 @@
server
{
listen 443 ssl http2 default_server;
server_name _;
proxy_set_header Host $host;
proxy_set_header Connection $http_connection;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Scheme $scheme;
location ~* ^/api/(.*?|v1|openai)/(v1|(generate|stream)|(chat/completions|completions))$
{
# Route to inference endpoints
proxy_pass http://127.0.0.1:5000;
# Required for streaming (both websockets and SSE).
proxy_buffering off;
proxy_cache off;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
# Set long timeouts for inference operations.
# Cloudflare has a timeout of 100 seconds.
proxy_read_timeout 120;
proxy_connect_timeout 120;
proxy_send_timeout 120;
}
location /
{
proxy_pass http://127.0.0.1:5000;
}
ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt;
ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key;
include /etc/nginx/snippets/ssl-params.conf;
}
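A rough check that streamed responses pass through this proxy unbuffered (hypothetical host and token; mirrors the curl tests below):
import requests
resp = requests.post(
    'https://proxy.example.com/api/openai/v1/chat/completions',
    headers={'Authorization': 'Bearer TEST_KEY'},
    json={
        'model': 'gpt-4',
        'messages': [{'role': 'user', 'content': 'Write a 300 word story about an apple tree.'}],
        'max_tokens': 100,
        'stream': True,
    },
    stream=True,
    timeout=120,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())  # chunks should arrive incrementally, not in one burst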

11
other/tests/config.sh Normal file
View File

@ -0,0 +1,11 @@
HOST="proxy.chub-archive.evulid.cc"
AUTH_KEY="TEST_1df979f0-6df1-41bd-814a-e99b1680e727"
PROXY_SERVERS=(
"http://172.0.4.7:3128"
"http://172.0.4.8:3128"
"http://172.0.4.10:3128"
"http://172.0.4.12:3128"
"http://172.0.4.13:3128"
)

58
other/tests/generate.sh Executable file
View File

@ -0,0 +1,58 @@
#!/bin/bash
SLEEP_TIME=2
while getopts p:t: flag; do
case "${flag}" in
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ -n "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"prompt": "Write a 300 word story about an apple tree.",
"temperature": 1,
"max_new_tokens": 100,
"top_p": 1.0,
"top_k": -1,
"use_beam_search": false,
"stop": ["TEST"],
"ignore_eos": false,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"length_penalty": 1.0,
"early_stopping": false
}
EOF
)
curl "https://$HOST/api/v1/generate" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

View File

@ -0,0 +1,52 @@
#!/bin/bash
DO_STREAM=false
SLEEP_TIME=2
while getopts p:t:s flag; do
case "${flag}" in
s) DO_STREAM=true ;;
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ ! -z "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Write a 300 word story about an apple tree."}],
"max_tokens": 100,
"stream": $DO_STREAM
}
EOF
)
curl "https://$HOST/api/openai/v1/chat/completions" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

52
other/tests/oai-completion.sh Executable file
View File

@ -0,0 +1,52 @@
#!/bin/bash
DO_STREAM=false
SLEEP_TIME=2
while getopts p:t:s flag; do
case "${flag}" in
s) DO_STREAM=true ;;
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ ! -z "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"model": "gpt-4",
"prompt": "Write a 300 word story about an apple tree.",
"max_tokens": 100,
"stream": $DO_STREAM
}
EOF
)
curl "https://$HOST/api/openai/v1/completions" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

64
other/tests/start-bulk.sh Executable file
View File

@ -0,0 +1,64 @@
#!/bin/bash
# Function to display help message
function display_help {
echo "Usage: $0 -n num_windows -c command"
echo
echo " -n, --number Number of windows to create"
echo " -c, --command Command to run in each window"
echo
exit 1
}
# Parse command line arguments
while getopts "n:c:h" opt; do
case ${opt} in
n)
num_windows=${OPTARG}
;;
c)
command=${OPTARG}
;;
h)
display_help
;;
\?)
echo "Invalid option: -$OPTARG" 1>&2
display_help
;;
:)
echo "Option -$OPTARG requires an argument." 1>&2
display_help
;;
esac
done
# Check if number of windows and command are provided
if [ -z "$num_windows" ] || [ -z "$command" ]; then
echo "Both number of windows and command are required."
display_help
fi
# Calculate rows and columns
rows=$(echo "sqrt($num_windows)" | bc)
columns=$(echo "($num_windows + $rows - 1) / $rows" | bc)
# Create a new tmux session
tmux new-session -d -s llm_tester "$command -p 0"
# Create the remaining windows
for ((i = 1; i < $num_windows; i++)); do
if ((i % $columns == 0)); then
tmux select-layout -t llm_tester:0 tiled
tmux select-pane -t 0
tmux split-window -t llm_tester:0 -v "$command -p $i"
else
tmux split-window -t llm_tester:0 -h "$command -p $i"
fi
done
# Balance the windows
tmux select-layout -t llm_tester:0 tiled
# Attach to the session
tmux attach-session -t llm_tester
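For reference, the same rows/columns math the bc calls compute, sketched in Python (window counts are illustrative):
import math
for num_windows in (4, 6, 9):
    rows = int(math.sqrt(num_windows))  # bc's sqrt truncates at the default scale
    columns = (num_windows + rows - 1) // rows
    print(num_windows, '->', rows, 'rows x', columns, 'columns')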

62
other/ooba-test-streaming.py → other/tests/stream.py Normal file → Executable file
View File

@ -1,37 +1,50 @@
import asyncio import asyncio
import json import json
import sys import sys
import os
import time
from pathlib import Path
try: try:
import websockets import websockets
except ImportError: except ImportError:
print("Websockets package not found. Make sure it's installed.") print("Websockets package not found. Make sure it's installed.")
# For local streaming, the websockets are hosted without ssl - ws:// script_path = os.path.dirname(os.path.realpath(__file__))
HOST = 'localhost:5000'
URI = f'ws://{HOST}/api/v1/stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' def parse_bash_config(file_path):
config = {}
with open(file_path, 'r') as f:
for line in f:
if line.startswith('#') or '=' not in line:
continue
key, value = line.strip().split('=', 1)
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
elif value.startswith('(') and value.endswith(')'):
value = value[1:-1].split()
value = [v.strip('"') for v in value]
config[key] = value
return config
config = parse_bash_config(Path(script_path, 'config.sh'))
async def run(context): async def run(context):
# Note: the selected defaults change from time to time.
request = { request = {
'prompt': context, 'prompt': context,
'max_new_tokens': 250, 'max_new_tokens': 250,
'auto_max_new_tokens': False, 'auto_max_new_tokens': False,
'max_tokens_second': 0, 'max_tokens_second': 0,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None', 'preset': 'None',
'do_sample': True, 'do_sample': True,
'temperature': 0.7, 'temperature': 0.7,
'top_p': 0.1, 'top_p': 0.1,
'typical_p': 1, 'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4 'epsilon_cutoff': 0,
'eta_cutoff': 0, # In units of 1e-4 'eta_cutoff': 0,
'tfs': 1, 'tfs': 1,
'top_a': 0, 'top_a': 0,
'repetition_penalty': 1.18, 'repetition_penalty': 1.18,
@ -48,7 +61,6 @@ async def run(context):
'mirostat_eta': 0.1, 'mirostat_eta': 0.1,
'guidance_scale': 1, 'guidance_scale': 1,
'negative_prompt': '', 'negative_prompt': '',
'seed': -1, 'seed': -1,
'add_bos_token': True, 'add_bos_token': True,
'truncation_length': 2048, 'truncation_length': 2048,
@ -58,7 +70,7 @@ async def run(context):
'stopping_strings': [] 'stopping_strings': []
} }
async with websockets.connect(URI, ping_interval=None) as websocket: async with websockets.connect(f'wss://{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket:
await websocket.send(json.dumps(request)) await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply yield context # Remove this if you just want to see the reply
@ -67,20 +79,28 @@ async def run(context):
incoming_data = await websocket.recv() incoming_data = await websocket.recv()
incoming_data = json.loads(incoming_data) incoming_data = json.loads(incoming_data)
print(incoming_data)
match incoming_data['event']: match incoming_data['event']:
case 'text_stream': # case 'text_stream':
yield incoming_data['text'] # yield incoming_data['text']
case 'stream_end': case 'stream_end':
return return
async def print_response_stream(prompt): async def print_response_stream(prompt):
async for response in run(prompt): try:
print(response, end='') async for response in run(prompt):
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. print(response, end='')
print('\n\nfinished') sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
except Exception as e:
print(e)
if __name__ == '__main__': if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)" prompt = "Write a 300 word story about an apple tree.\n\n"
asyncio.run(print_response_stream(prompt)) while True:
print('--> START <--')
asyncio.run(print_response_stream(prompt))
print('--> DONE <--')
time.sleep(2)
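For reference, a sketch of what parse_bash_config yields for simple one-line assignments (the multi-line PROXY_SERVERS array in config.sh doesn't fit the line-by-line parser, but this script only reads HOST):
# Illustrative only; run within this module so parse_bash_config is in scope.
with open('/tmp/example_config.sh', 'w') as f:
    f.write('HOST="proxy.example.com"\nAUTH_KEY="TEST_KEY"\n')
print(parse_bash_config('/tmp/example_config.sh'))
# -> {'HOST': 'proxy.example.com', 'AUTH_KEY': 'TEST_KEY'}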

View File

@ -1,21 +1,18 @@
flask~=2.3.3 flask~=2.3.3
flask_cors
pyyaml~=6.0.1 pyyaml~=6.0.1
flask_caching Flask-Caching==2.0.2
requests~=2.31.0 requests~=2.31.0
tiktoken~=0.5.0 tiktoken~=0.5.0
gunicorn
gevent~=23.9.0.post1 gevent~=23.9.0.post1
async-timeout
flask-sock
uvicorn~=0.23.2
fastapi~=0.103.1
torch~=2.0.1
PyMySQL~=1.1.0 PyMySQL~=1.1.0
DBUtils~=3.0.3
simplejson~=3.19.1 simplejson~=3.19.1
websockets~=11.0.3 websockets~=11.0.3
basicauth~=1.0.0 basicauth~=1.0.0
openai~=0.28.0 openai~=0.28.0
urllib3~=2.0.4 flask-sock==0.6.0
celery[redis] gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm==0.2.1.post1
gradio~=3.46.1
coloredlogs~=15.0.1

138
server.py
View File

@ -1,5 +1,3 @@
from llm_server.config.config import mode_ui_names
try: try:
import gevent.monkey import gevent.monkey
@ -7,37 +5,46 @@ try:
except ImportError: except ImportError:
pass pass
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config
import os import os
import sys import sys
from pathlib import Path from pathlib import Path
import simplejson as json import simplejson as json
from flask import Flask, jsonify, render_template, request from flask import Flask, jsonify, render_template, request, Response
import llm_server import config
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database from llm_server.database.conn import database
from llm_server.database.create import create_db from llm_server.database.create import create_db
from llm_server.llm import get_token_count from llm_server.helpers import auto_set_base_client_api
from llm_server.routes.openai import openai_bp from llm_server.llm.vllm.info import vllm_info
from llm_server.pre_fork import server_startup
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.sock import init_wssocket
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation. # TODO: separate queue item timeout for websockets (make longer, like 5 minutes)
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail # TODO: return an `error: True`, error code, and error message rather than just a formatted message
# TODO: implement background thread to test backends via sending test prompts # TODO: what happens when all backends are offline? What about the "online" key in the stats page?
# TODO: if backend fails request, mark it as down # TODO: redis SCAN vs KEYS??
# TODO: allow setting concurrent gens per-backend # TODO: is frequency penalty the same as ooba repetition penalty???
# TODO: set the max tokens to that of the lowest backend # TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# TODO: implement RRD backend loadbalancer option
# TODO: have VLLM reject a request if it already has n == concurrent_gens running
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: use coloredlogs
# TODO: need to update opts. for workers
# Lower priority # Lower priority
# TODO: if a backend is at its limit of concurrent requests, choose a different one
# TODO: make error messages consistent
# TODO: support logit_bias on OpenAI and Ooba endpoints.
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: validate openai_silent_trim works as expected and only when enabled
# TODO: rewrite config storage. Store in redis so we can reload it.
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens # TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estimated wait time lags behind the stats # TODO: the estimated wait time lags behind the stats
# TODO: simulate OpenAI error messages regardless of endpoint # TODO: simulate OpenAI error messages regardless of endpoint
@ -59,19 +66,16 @@ except ModuleNotFoundError as e:
print('Please see README.md for install instructions.') print('Please see README.md for install instructions.')
sys.exit(1) sys.exit(1)
import config
from llm_server import opts
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
app = Flask(__name__) app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/') # Fixes ConcurrentObjectUseError
# https://github.com/miguelgrinberg/simple-websocket/issues/24
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
init_wssocket(app)
flask_cache.init_app(app) flask_cache.init_app(app)
flask_cache.clear() flask_cache.clear()
@ -82,18 +86,13 @@ if config_path_environ:
else: else:
config_path = Path(script_path, 'config', 'config.yml') config_path = Path(script_path, 'config', 'config.yml')
success, config, msg = load_config(config_path, script_path) success, config, msg = load_config(config_path)
if not success: if not success:
print('Failed to load config:', msg) print('Failed to load config:', msg)
sys.exit(1) sys.exit(1)
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database']) database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db() create_db()
llm_server.llm.redis = RedisWrapper('local_llm')
create_db()
# print(app.url_map)
@app.route('/') @app.route('/')
@ -101,20 +100,30 @@ create_db()
@app.route('/api/openai') @app.route('/api/openai')
@flask_cache.cached(timeout=10) @flask_cache.cached(timeout=10)
def home(): def home():
base_client_api = redis.get('base_client_api', dtype=str)
stats = generate_stats() stats = generate_stats()
model_choices, default_model = get_model_choices()
if not stats['online']: if default_model:
running_model = estimated_wait_sec = 'offline' if not model_choices.get(default_model):
else: return 'The server is still starting up. Please wait...'
running_model = redis.get('running_model', str, 'ERROR')
active_gen_workers = get_active_gen_workers() default_model_info = model_choices[default_model]
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
if default_model_info['queued'] == 0 and default_model_info['processing'] >= default_model_info['concurrent_gens']:
# There will be a wait if the queue is empty but prompts are processing, but we don't # There will be a wait if the queue is empty but prompts are processing, but we don't
# know how long. # know how long.
estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds" default_estimated_wait_sec = f"less than {int(default_model_info['estimated_wait'])} seconds"
else: else:
estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds" default_estimated_wait_sec = f"{int(default_model_info['estimated_wait'])} seconds"
else:
default_model_info = {
'model': 'OFFLINE',
'processing': '-',
'queued': '-',
'context_size': '-',
}
default_estimated_wait_sec = 'OFFLINE'
if len(config['analytics_tracking_code']): if len(config['analytics_tracking_code']):
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>" analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
@ -127,32 +136,47 @@ def home():
info_html = '' info_html = ''
mode_info = '' mode_info = ''
if opts.mode == 'vllm': for k, v in cluster_config.all().items():
mode_info = vllm_info if v['mode'] == 'vllm':
mode_info = vllm_info
base_client_api = redis.get('base_client_api', str) break
return render_template('home.html', return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name, llm_middleware_name=opts.llm_middleware_name,
analytics_tracking_code=analytics_tracking_code, analytics_tracking_code=analytics_tracking_code,
info_html=info_html, info_html=info_html,
current_model=opts.manual_model_name if opts.manual_model_name else running_model, default_model=default_model_info['model'],
default_active_gen_workers=default_model_info['processing'],
default_proompters_in_queue=default_model_info['queued'],
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}', client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None, ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
estimated_wait=estimated_wait_sec, default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0], mode_name=mode_ui_names[opts.frontend_api_mode][0],
api_input_textbox=mode_ui_names[opts.mode][1], api_input_textbox=mode_ui_names[opts.frontend_api_mode][1],
streaming_input_textbox=mode_ui_names[opts.mode][2], streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2],
context_size=opts.context_size, default_context_size=default_model_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False), stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info, extra_info=mode_info,
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled', openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
expose_openai_system_prompt=opts.expose_openai_system_prompt, expose_openai_system_prompt=opts.expose_openai_system_prompt,
enable_streaming=opts.enable_streaming, enable_streaming=opts.enable_streaming,
model_choices=model_choices,
proompters_5_min=stats['stats']['proompters']['5_min'],
proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
) )
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend @app.route('/robots.txt')
def robots():
# TODO: have config value to deny all
# TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt
t = """User-agent: *
Allow: /"""
r = Response(t)
r.headers['Content-Type'] = 'text/plain'
return r
@app.route('/<first>') @app.route('/<first>')
@app.route('/<first>/<path:rest>') @app.route('/<first>/<path:rest>')

View File

@ -65,6 +65,19 @@
.hidden { .hidden {
display: none; display: none;
} }
.header-workers {
font-weight: normal;
font-size: 14pt;
}
h3 {
font-size: 16pt;
}
.no-marker {
list-style: none;
}
</style> </style>
</head> </head>
@ -76,8 +89,12 @@
<h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1> <h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1>
<div class="info-box"> <div class="info-box">
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p> <p><strong>Current Model:</strong> <span id="model">{{ default_model }}</span></p>
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p> <p>
<strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ default_estimated_wait }}</span><br>
Processing: {{ default_active_gen_workers }}<br>
Queued: {{ default_proompters_in_queue }}
</p>
<br> <br>
<p><strong>Client API URL:</strong> {{ client_api }}</p> <p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p> <p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
@ -91,17 +108,20 @@
<br> <br>
<div class="info-box"> <div class="info-box">
<div id="oobabooga"> <h3>Instructions</h3>
<strong>Instructions:</strong> <div id="instructions">
<ol> <ol>
<li>In Settings > Power User Options, enable <kbd>Relaxed API URLS</kbd>.</li>
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li> <li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li> <li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
{% if enable_streaming %}<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>{% endif %} {% if enable_streaming %}
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
{% endif %}
<li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer <li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox. API key</kbd> textbox.
</li> </li>
<li>Click <kbd>Connect</kbd> to test the connection.</li> <li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li> <li>Open your preset config and set <kbd>Context Size</kbd> to {{ default_context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a> <li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
</li> </li>
</ol> </ol>
@ -120,13 +140,45 @@
<br> <br>
<div class="info-box"> <div class="info-box">
<pre><code class="language-json" style="background-color: white">{{ stats_json|safe }}</code></pre> <h3>Statistics</h3>
Proompters:
<ul style="margin-top: 5px;">
<li class="no-marker">5 minutes: {{ proompters_5_min }}</li>
<li class="no-marker">24 hours: {{ proompters_24_hrs }}</li>
</ul>
</div> </div>
<br>
{% for key, value in model_choices.items() %}
<div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
{% if value.queued == 0 and value.processing >= value.concurrent_gens %}
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
{% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
{% else %}
{% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
{% endif %}
<p>
<strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
Processing: {{ value.processing }}<br>
Queued: {{ value.queued }}<br>
</p>
<p>
<strong>Client API URL:</strong> {{ value.client_api }}<br>
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
</p>
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
</div>
<br>
{% endfor %}
</div> </div>
<div class="footer"> <div class="footer">
<a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a> <a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a>
</div> </div>
<script>hljs.highlightAll();</script>
</body> </body>
</html> </html>