Merge cluster to master (#3)

Co-authored-by: Cyberes <cyberes@evulid.cc>
Reviewed-on: #3
This commit is contained in:
Cyberes 2023-10-27 19:19:22 -06:00
parent 561820fb9e
commit 0059e7956c
80 changed files with 2911 additions and 1258 deletions

View File

@ -43,7 +43,9 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
### To Do

View File

@ -1,22 +1,19 @@
import time
from llm_server.routes.cache import redis
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import argparse
import logging
import os
import sys
import time
from pathlib import Path
from llm_server.config.load import load_config
from llm_server.database.create import create_db
from redis import Redis
from llm_server.workers.app import start_background
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.load import load_config, parse_backends
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@ -26,19 +23,46 @@ else:
config_path = Path(script_path, 'config', 'config.yml')
if __name__ == "__main__":
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
parser = argparse.ArgumentParser(description='Daemon microservice.')
parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
args = parser.parse_args()
success, config, msg = load_config(config_path, script_path)
# TODO: have this be set by either the arg or a config value
if args.debug:
logging_info.level = logging.DEBUG
init_logging()
logger = create_logger('daemon')
logger.debug('Debug logging enabled.')
if not args.no_reset:
Redis().flushall()
logger.info('Flushed Redis.')
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
logger.info(f'Failed to load config: {msg}')
sys.exit(1)
create_db()
cluster_config.clear()
cluster_config.load(parse_backends(config))
logger.info('Loading backend stats...')
generate_stats(regen=True)
start_background()
redis.set('daemon_started', 1)
print('== Daemon Setup Complete ==\n')
# Give some time for the background threads to get themselves ready to go.
time.sleep(2)
while True:
time.sleep(3600)
redis.set('daemon_started', 1)
logger.info('== Daemon Setup Complete ==')
try:
while True:
time.sleep(3600)
except KeyboardInterrupt:
redis.set('daemon_started', 0)

View File

View File

@ -0,0 +1,117 @@
import numpy as np
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.llm.info import get_info
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
def get_backends_from_model(model_name: str):
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
def get_running_models():
return redis_running_models.keys()
def purge_backend_from_running_models(backend_url: str):
keys = redis_running_models.keys()
pipeline = redis_running_models.pipeline()
for model in keys:
pipeline.srem(model, backend_url)
pipeline.execute()
def is_valid_model(model_name: str):
return redis_running_models.exists(model_name)
def test_backend(backend_url: str, test_prompt: bool = False):
backend_info = cluster_config.get_backend(backend_url)
if test_prompt:
handler = VLLMBackend(backend_url)
parameters, _ = handler.get_parameters({
"stream": False,
"temperature": 0,
"max_new_tokens": 3,
})
data = {
'prompt': 'test prompt',
**parameters
}
try:
success, response, err = generator(data, backend_url, timeout=10)
if not success or not response or err:
return False, {}
except:
return False, {}
i = get_info(backend_url, backend_info['mode'])
if not i.get('model'):
return False, {}
return True, i
def get_model_choices(regen: bool = False):
if not regen:
c = redis.getp('model_choices')
if c:
return c
base_client_api = redis.get('base_client_api', dtype=str)
running_models = get_running_models()
model_choices = {}
for model in running_models:
b = get_backends_from_model(model)
context_size = []
avg_gen_per_worker = []
concurrent_gens = 0
for backend_url in b:
backend_info = cluster_config.get_backend(backend_url)
if backend_info.get('model_config'):
context_size.append(backend_info['model_config']['max_position_embeddings'])
if backend_info.get('average_generation_elapsed_sec'):
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
concurrent_gens += backend_info['concurrent_gens']
active_gen_workers = get_active_gen_workers_model(model)
proompters_in_queue = priority_queue.len(model)
if len(avg_gen_per_worker):
average_generation_elapsed_sec = np.average(avg_gen_per_worker)
else:
average_generation_elapsed_sec = 0
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, concurrent_gens, active_gen_workers)
model_choices[model] = {
'model': model,
'client_api': f'https://{base_client_api}/{model}',
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
'backend_count': len(b),
'estimated_wait': estimated_wait_sec,
'queued': proompters_in_queue,
'processing': active_gen_workers,
'avg_generation_time': average_generation_elapsed_sec,
'concurrent_gens': concurrent_gens
}
if len(context_size):
model_choices[model]['context_size'] = min(context_size)
# Python wants to sort lowercase vs. uppercase letters differently.
model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
default_backend_url = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend_url)
if not default_backend_info.get('model'):
return {}, None
default_model = default_backend_info['model']
redis.setp('model_choices', (model_choices, default_model))
return model_choices, default_model

View File

@ -0,0 +1,124 @@
import hashlib
import pickle
import traceback
from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import RedisCustom
from llm_server.routes.helpers.model import estimate_model_size
class RedisClusterStore:
def __init__(self, name: str, **kwargs):
self.name = name
self.config_redis = RedisCustom(name, **kwargs)
def clear(self):
self.config_redis.flush()
def load(self, config: dict):
for k, v in config.items():
self.add_backend(k, v)
def add_backend(self, name: str, values: dict):
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
self.set_backend_value(name, 'online', False)
h = hashlib.sha256(name.encode('utf-8')).hexdigest()
self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
def set_backend_value(self, backend: str, key: str, value):
# By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
self.config_redis.hset(backend, key, pickle.dumps(value))
def get_backend(self, name: str):
r = self.config_redis.hgetall(name)
output = {}
for k, v in r.items():
output[k.decode('utf8')] = pickle.loads(v)
return output
def all(self):
keys = self.config_redis.keys('*')
if keys:
result = {}
for key in keys:
if key != f'{self.name}:____':
v = self.get_backend(key)
result[key] = v
return result
else:
return {}
def validate_backend(self, backend_url: str):
"""
Returns the backend URL that was given, or a new one if that was offline.
:param backend_url:
:return:
"""
backend_info = self.get_backend(backend_url)
if not backend_info['online']:
old = backend_url
backend_url = get_a_cluster_backend()
print(f'Backend {old} offline. Request was redirected to {backend_url}')
return backend_url
cluster_config = RedisClusterStore('cluster_config')
def get_backends():
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b.get('online', False)
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
try:
if not opts.prioritize_by_size:
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
else:
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: estimate_model_size(kv[1]['model_config']),
reverse=True
)
offline_backends = sorted(
((url, info) for url, info in backends.items() if not info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends], [url for url, info in offline_backends]
except KeyError:
traceback.print_exc()
print(backends)
def get_a_cluster_backend(model=None):
"""
Get a backend from Redis. If there are no online backends, return None.
If `model` is not supplied, we will pick one ourself.
"""
if model:
# First, determine if there are multiple backends hosting the same model.
backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]
# If so, create an iterator for those backends
if len(backends_hosting_model):
add_backend_cycler(model, backends_hosting_model)
cycled = redis_cycle(model)
if len(cycled):
return cycled[0]
else:
# No backend hosting that model
return None
else:
online, _ = get_backends()
if len(online):
return online[0]

View File

@ -0,0 +1,39 @@
import redis
redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)
def redis_cycle(list_name):
"""
Emulates itertools.cycle() but returns the complete shuffled list.
:param list_name:
:return:
"""
pipeline = redis_cycler_db.pipeline()
pipeline.lpop(list_name)
to_move = pipeline.execute()[0]
if not to_move:
return []
pipeline.rpush(list_name, to_move)
pipeline.lrange(list_name, 0, -1)
results = pipeline.execute()
new_list = results[-1]
return [x.decode('utf-8') for x in new_list]
def add_backend_cycler(list_name: str, new_elements: list):
existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
existing_set = set(existing_elements)
with redis_cycler_db.pipeline() as pipe:
# Add elements
for element in new_elements:
if element not in existing_set:
pipe.rpush(list_name, element)
# Remove elements
for element in existing_set:
if element not in new_elements:
pipe.lrem(list_name, 0, element)
pipe.execute()

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_running_models = RedisCustom('running_models')

View File

@ -0,0 +1,38 @@
import time
from threading import Thread
from llm_server.cluster.backend import test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
def cluster_worker():
counter = 0
while True:
test_prompt = False
if counter % 4 == 0:
# Only send a test prompt every 120 seconds.
test_prompt = True
threads = []
for n, v in cluster_config.all().items():
thread = Thread(target=check_backend, args=(n, v, test_prompt))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
time.sleep(15)
counter += 1
def check_backend(n, v, test_prompt):
online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
if online:
running_model = backend_info['model']
for k, v in backend_info.items():
cluster_config.set_backend_value(n, k, v)
redis_running_models.sadd(running_model, n)
else:
for model in redis_running_models.keys():
redis_running_models.srem(model, n)
cluster_config.set_backend_value(n, 'online', online)

View File

@ -28,16 +28,19 @@ config_default_vars = {
'openai_force_no_hashes': True,
'include_system_tokens_in_stats': True,
'openai_moderation_scan_last_n': 5,
'openai_moderation_workers': 10,
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True,
'netdata_root': None
'netdata_root': None,
'show_backends': True,
'background_homepage_cacher': True,
'openai_moderation_timeout': 5,
'prioritize_by_size': False
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
mode_ui_names = {
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
}

View File

@ -3,38 +3,28 @@ import sys
import openai
import llm_server
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.routes.cache import redis
from llm_server.routes.queue import PriorityQueue
def load_config(config_path, script_path):
def load_config(config_path):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
return success, config, msg
# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.cluster = config['cluster']
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@ -53,10 +43,20 @@ def load_config(config_path, script_path):
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
opts.background_homepage_cacher = config['background_homepage_cacher']
opts.openai_moderation_timeout = config['openai_moderation_timeout']
opts.frontend_api_mode = config['frontend_api_mode']
opts.prioritize_by_size = config['prioritize_by_size']
# Scale the number of workers.
for item in config['cluster']:
opts.cluster_workers += item['concurrent_gens']
llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
@ -78,6 +78,16 @@ def load_config(config_path, script_path):
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
redis.set('backend_mode', opts.mode)
return success, config, msg
def parse_backends(config):
if not config.get('cluster'):
return False
cluster = config.get('cluster')
config = {}
for item in cluster:
backend_url = item['backend_url'].strip('/')
item['backend_url'] = backend_url
config[backend_url] = item
return config

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_config = RedisCustom('redis_config')

View File

@ -1,24 +1,27 @@
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Union
from typing import Callable, List, Mapping, Optional, Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
ONE_MONTH_SECONDS = 2678000
class RedisWrapper:
class RedisCustom(Redis):
"""
A wrapper class to set prefixes to keys.
A simple wrapper class for Redis to create a "namespace" within a DB,
which simplyifies key management.
"""
def __init__(self, prefix, **kwargs):
super().__init__()
self.redis = Redis(**kwargs)
self.prefix = prefix
try:
@ -34,12 +37,11 @@ class RedisWrapper:
def set(self, key, value, ex: Union[ExpiryT, None] = None):
return self.redis.set(self._key(key), value, ex=ex)
def get(self, key, dtype=None, default=None):
"""
:param key:
:param dtype: convert to this type
:return:
"""
def get(self, key, default=None, dtype=None):
# TODO: use pickle
import inspect
if inspect.isclass(default):
raise Exception
d = self.redis.get(self._key(key))
if dtype and d:
@ -108,7 +110,10 @@ class RedisWrapper:
):
return self.redis.hincrby(self._key(name), key, amount)
def hdel(self, name: str, *keys: List):
def zcard(self, name: KeyT):
return self.redis.zcard(self._key(name))
def hdel(self, name: str, *keys: str):
return self.redis.hdel(self._key(name), *keys)
def hget(
@ -129,9 +134,62 @@ class RedisWrapper:
):
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
def lpush(self, name: str, *values: FieldT):
return self.redis.lpush(self._key(name), *values)
def hset(
self,
name: str,
key: Optional = None,
value=None,
mapping: Optional[dict] = None,
items: Optional[list] = None,
):
return self.redis.hset(self._key(name), key, value, mapping, items)
def hkeys(self, name: str):
return self.redis.hkeys(self._key(name))
def hmget(self, name: str, keys: List, *args: List):
return self.redis.hmget(self._key(name), keys, *args)
def hgetall(self, name: str):
return self.redis.hgetall(self._key(name))
def keys(self, pattern: PatternT = "*", **kwargs):
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
keys = []
for key in raw_keys:
p = key.decode('utf-8').split(':')
if len(p) >= 2:
# Delete prefix
del p[0]
k = ':'.join(p)
if k != '____':
keys.append(k)
return keys
def pipeline(self, transaction=True, shard_hint=None):
return self.redis.pipeline(transaction, shard_hint)
def smembers(self, name: str):
return self.redis.smembers(self._key(name))
def spop(self, name: str, count: Optional[int] = None):
return self.redis.spop(self._key(name), count)
def rpoplpush(self, src, dst):
return self.redis.rpoplpush(src, dst)
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
return self.redis.zpopmin(self._key(name), count)
def exists(self, *names: KeyT):
n = []
for name in names:
n.append(self._key(name))
return self.redis.exists(*n)
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex)
@ -142,6 +200,15 @@ class RedisWrapper:
else:
return json.loads(r.decode("utf-8"))
def setp(self, name, value):
self.redis.set(self._key(name), pickle.dumps(value))
def getp(self, name: str):
r = self.redis.get(self._key(name))
if r:
return pickle.loads(r)
return r
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):
@ -149,5 +216,40 @@ class RedisWrapper:
self.redis.delete(key)
return flushed
def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
self.flush()
return True
redis = RedisWrapper('local_llm')
def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
self.flush()
return True
def lrange(self, name: str, start: int, end: int):
return self.redis.lrange(self._key(name), start, end)
def delete(self, *names: KeyT):
return self.redis.delete(*[self._key(i) for i in names])
def lpop(self, name: str, count: Optional[int] = None):
return self.redis.lpop(self._key(name), count)
def zrange(
self,
name: KeyT,
start: int,
end: int,
desc: bool = False,
withscores: bool = False,
score_cast_func: Union[type, Callable] = float,
byscore: bool = False,
bylex: bool = False,
offset: int = None,
num: int = None,
):
return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
def zrem(self, name: KeyT, *values: FieldT):
return self.redis.zrem(self._key(name), *values)
redis = RedisCustom('local_llm')

View File

@ -5,20 +5,20 @@ class DatabaseConnection:
host: str = None
username: str = None
password: str = None
database: str = None
database_name: str = None
def init_db(self, host, username, password, database):
def init_db(self, host, username, password, database_name):
self.host = host
self.username = username
self.password = password
self.database = database
self.database_name = database_name
def cursor(self):
db = pymysql.connect(
host=self.host,
user=self.username,
password=self.password,
database=self.database,
database=self.database_name,
charset='utf8mb4',
autocommit=True,
)

View File

@ -1,15 +1,19 @@
import json
import time
import traceback
from typing import Union
import llm_server
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
from llm_server.routes.cache import redis
from llm_server.llm import get_token_count
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
# Try not to shove JSON into the database.
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
try:
@ -19,10 +23,11 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
except:
pass
prompt_tokens = llm_server.llm.get_token_count(prompt)
prompt_tokens = get_token_count(prompt, backend_url)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)
response_tokens = get_token_count(response, backend_url)
else:
response_tokens = None
@ -43,7 +48,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if token:
increment_token_uses(token)
running_model = redis.get('running_model', str, 'ERROR')
backend_info = cluster_config.get_backend(backend_url)
running_model = backend_info.get('model')
backend_mode = backend_info['mode']
timestamp = int(time.time())
cursor = database.cursor()
try:
@ -52,7 +59,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
@ -179,3 +186,21 @@ def increment_token_uses(token):
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
finally:
cursor.close()
def get_token_ratelimit(token):
priority = 9990
simultaneous_ip = opts.simultaneous_requests_per_ip
if token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip

View File

@ -0,0 +1,30 @@
import pickle
from typing import Union
from redis import Redis
def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
r = Redis(host='localhost', port=6379, db=3)
data = {
'function': 'log_prompt',
'args': [],
'kwargs': {
'ip': ip,
'token': token,
'prompt': prompt,
'response': response,
'gen_time': gen_time,
'parameters': parameters,
'headers': dict(headers) if headers else headers,
'backend_response_code': backend_response_code,
'request_url': request_url,
'backend_url': backend_url,
'response_tokens': response_tokens,
'is_error': is_error
}
}
r.publish('database-logger', pickle.dumps(data))

View File

@ -8,7 +8,7 @@ import simplejson as json
from flask import make_response
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def resolve_path(*p: str):
@ -54,13 +54,14 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys
def round_up_base(n, base):
if base == 0:
print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
# TODO: I don't think passing (0, 0) to this function is a sign of any underlying issues.
# print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
return 0
return math.ceil(n / base) * base
def auto_set_base_client_api(request):
http_host = redis.get('http_host', str)
http_host = redis.get('http_host', dtype=str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
# If the current http_host is not an IP, don't do anything.

View File

@ -1,12 +0,0 @@
import threading
class ThreadSafeInteger:
def __init__(self, value=0):
self.value = value
self._value_lock = threading.Lock()
def increment(self):
with self._value_lock:
self.value += 1
return self.value

View File

@ -1,11 +1,30 @@
import tiktoken
from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import oobabooga, vllm
from llm_server.routes.cache import redis
from llm_server.logging import create_logger
def get_token_count(prompt: str):
backend_mode = redis.get('backend_mode', str)
def fallback_tokenizer(prompt: str):
tokenizer = tiktoken.get_encoding("cl100k_base")
return len(tokenizer.encode(prompt)) + 10
def get_token_count(prompt: str, backend_url: str):
backend_url = cluster_config.validate_backend(backend_url)
if not backend_url:
logger = create_logger('tokenizer')
logger.warning('using fallback tokenizer as there is no valid backend')
return fallback_tokenizer(prompt)
backend_mode = cluster_config.get_backend(backend_url).get('mode')
if not backend_mode:
logger = create_logger('tokenizer')
logger.warning("using fallback tokenizer as the backend isn't initalized")
return fallback_tokenizer(prompt)
if backend_mode == 'vllm':
return vllm.tokenize(prompt)
return vllm.tokenize(prompt, backend_url)
elif backend_mode == 'ooba':
return oobabooga.tokenize(prompt)
else:

View File

@ -1,14 +1,15 @@
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
def generator(request_json_body):
if opts.mode == 'oobabooga':
def generator(request_json_body, cluster_backend, timeout: int = None):
mode = cluster_config.get_backend(cluster_backend)['mode']
if mode == 'ooba':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm':
elif mode == 'vllm':
from .vllm.generate import generate
r = generate(request_json_body)
return r
return generate(request_json_body, cluster_backend, timeout=timeout)
else:
raise Exception

View File

@ -3,23 +3,35 @@ import requests
from llm_server import opts
def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
# TODO: only use one try/catch
if opts.mode == 'oobabooga':
def get_running_model(backend_url: str, mode: str):
if mode == 'ooba':
try:
backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['result'], None
except Exception as e:
return False, e
elif opts.mode == 'vllm':
elif mode == 'vllm':
try:
backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['model'], None
except Exception as e:
return False, e
else:
raise Exception
def get_info(backend_url: str, mode: str):
if mode == 'ooba':
return {}
# raise NotImplementedError
elif mode == 'vllm':
try:
r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
j = r.json()
except Exception as e:
return {}
return j
else:
raise Exception

View File

@ -2,14 +2,17 @@ from typing import Tuple, Union
import flask
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
class LLMBackend:
_default_params: dict
def __init__(self, backend_url: str):
self.backend_url = backend_url
self.backend_info = cluster_config.get_backend(self.backend_url)
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError
@ -32,14 +35,16 @@ class LLMBackend:
"""
If a backend needs to do other checks not related to the prompt or parameters.
Default is no extra checks preformed.
:param request:
:param prompt:
:param parameters:
:return:
"""
return True, None
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = get_token_count(prompt)
if prompt_len > opts.context_size - 10:
model_name = redis.get('running_model', str, 'NO MODEL ERROR')
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
prompt_len = get_token_count(prompt, self.backend_url)
token_limit = self.backend_info['model_config']['max_position_embeddings']
if prompt_len > token_limit - 10:
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
return True, None

View File

@ -1,78 +1,6 @@
from flask import jsonify
from ..llm_backend import LLMBackend
from ...database.database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
class OobaboogaBackend(LLMBackend):
default_params = {}
def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError('need to implement default_params')
backend_err = False
response_valid_json, response_json_body = validate_json(response)
if response:
try:
# Be extra careful when getting attributes from the response object
response_status_code = response.status_code
except:
response_status_code = 0
else:
response_status_code = None
# ===============================================
# We encountered an error
if not success or not response or error_msg:
if not error_msg or error_msg == '':
error_msg = 'Unknown error.'
else:
error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, 'error')
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 400
# ===============================================
if response_valid_json:
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
# Ooba doesn't return any error messages so we will just tell the client an error occurred
backend_err = True
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
'error')
response_json_body['results'][0]['text'] = backend_response
if not backend_err:
redis.incr('proompts')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({
**response_json_body
}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': 'the backend did not return valid JSON',
'results': [{'text': backend_response}]
}), 400
def validate_params(self, params_dict: dict):
# No validation required
return True, None
def get_parameters(self, parameters):
del parameters['prompt']
return parameters
def __int__(self):
return

View File

@ -10,7 +10,7 @@ def check_moderation_endpoint(prompt: str):
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
if response.status_code != 200:
print(response.text)
print('moderation failed:', response)
response.raise_for_status()
response = response.json()

View File

@ -0,0 +1,97 @@
from flask import jsonify
from llm_server import opts
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
if not request_json_body.get('stop'):
request_json_body['stop'] = []
if not isinstance(request_json_body['stop'], list):
# It is a string, so create a list with the existing element.
request_json_body['stop'] = [request_json_body['stop']]
if stop_hashes:
if opts.openai_force_no_hashes:
request_json_body['stop'].append('###')
else:
# TODO: make stopping strings a configurable
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
else:
request_json_body['stop'].extend(['user:', 'assistant:'])
if request_json_body.get('frequency_penalty', 0) < -2:
request_json_body['frequency_penalty'] = -2
elif request_json_body.get('frequency_penalty', 0) > 2:
request_json_body['frequency_penalty'] = 2
if mode == 'vllm' and request_json_body.get('top_p') == 0:
request_json_body['top_p'] = 0.01
request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
if request_json_body['max_tokens'] == 0:
# We don't want to set any defaults here.
del request_json_body['max_tokens']
return request_json_body
def format_oai_err(err_msg):
print('OAI ERROR MESSAGE:', err_msg)
return jsonify({
"error": {
"message": err_msg,
"type": "invalid_request_error",
"param": None,
"code": None
}
}), 400
def validate_oai(parameters):
if parameters.get('messages'):
for m in parameters['messages']:
if m['role'].lower() not in ['assistant', 'user', 'system']:
return format_oai_err('messages role must be assistant, user, or system')
if parameters.get('temperature', 0) > 2:
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
if parameters.get('temperature', 0) < 0:
return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'")
if parameters.get('top_p', 1) > 2:
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
if parameters.get('top_p', 1) < 0:
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
if parameters.get('presence_penalty', 1) > 2:
return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
if parameters.get('presence_penalty', 1) < -2:
return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'")
if parameters.get('top_p', 1) > 2:
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
if parameters.get('top_p', 1) < 0:
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
if parameters.get('top_p', 1) > 2:
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
if parameters.get('top_p', 1) < 0:
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
if parameters.get('max_tokens', 2) < 1:
return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
def return_invalid_model_err(requested_model: str):
if requested_model:
msg = f"The model `{requested_model}` does not exist"
else:
msg = "The requested model does not exist"
return jsonify({
"error": {
"message": msg,
"type": "invalid_request_error",
"param": None,
"code": "model_not_found"
}
}), 404

View File

@ -2,86 +2,35 @@ import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List
import tiktoken
from flask import jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
def build_openai_response(prompt, response, model=None):
# Seperate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
# y = response.split('\n### ')
# if len(y) > 1:
# response = re.sub(r'\n$', '', y[0].strip(' '))
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
return len(tokenizer.encode(msg["content"]))
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
def get_token_count_thread(msg):
return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
# If total tokens exceed the limit, start trimming
if total_tokens > context_token_limit:
if total_tokens + formatting_tokens > context_token_limit:
while True:
while total_tokens + formatting_tokens > context_token_limit:
# Calculate the index to start removing messages from
@ -94,22 +43,43 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
break
def get_token_count_thread(msg):
return get_token_count(msg["content"])
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
else:
break
return prompt
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
tokenizer = tiktoken.get_encoding("cl100k_base")
token_count = get_token_count(prompt, backend_url)
# If total tokens exceed the limit, start trimming
if token_count > context_token_limit:
while True:
while token_count > context_token_limit:
# Calculate the index to start removing characters from
remove_index = len(prompt) // 3
while remove_index < len(prompt):
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
token_count = len(tokenizer.encode(prompt))
if token_count <= context_token_limit or remove_index == len(prompt):
break
token_count = get_token_count(prompt, backend_url)
if token_count > context_token_limit:
# Start over, but this time calculate the token count using the backend
token_count = get_token_count(prompt, backend_url)
else:
break
return prompt
@ -117,8 +87,9 @@ def transform_messages_to_prompt(oai_messages):
try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
for msg in oai_messages:
if not msg.get('content') or not msg.get('role'):
if 'content' not in msg.keys() or 'role' not in msg.keys():
return False
msg['content'] = str(msg['content']) # Prevent any weird issues.
if msg['role'] == 'system':
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
elif msg['role'] == 'user':
@ -126,7 +97,7 @@ def transform_messages_to_prompt(oai_messages):
elif msg['role'] == 'assistant':
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
else:
return False
raise Exception(f'Unknown role: {msg["role"]}')
except Exception as e:
# TODO: use logging
traceback.print_exc()

View File

@ -1,80 +1,21 @@
"""
This file is used by the worker that processes requests.
"""
import json
import time
from uuid import uuid4
import requests
import llm_server
from llm_server import opts
from llm_server.routes.cache import redis
# TODO: make the VLMM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
def prepare_json(json_data: dict):
# logit_bias is not currently supported
# del json_data['logit_bias']
# Convert back to VLLM.
json_data['max_tokens'] = json_data.pop('max_new_tokens')
return json_data
def transform_to_text(json_request, api_response):
"""
This is to convert a streaming request to a non-streamed request. Don't think this is nessesary.
:param json_request:
:param api_response:
:return:
"""
prompt = transform_prompt_to_text(json_request['messages'])
text = ''
finish_reason = None
for line in api_response.split('\n'):
if line.startswith('data:'):
try:
data = json.loads(line[5:].strip())
except json.decoder.JSONDecodeError:
break
if 'choices' in data:
for choice in data['choices']:
if 'delta' in choice and 'content' in choice['delta']:
text += choice['delta']['content']
if data['choices'][0]['finish_reason']:
finish_reason = data['choices'][0]['finish_reason']
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
completion_tokens = len(llm_server.llm.get_token_count(text))
running_model = redis.get('running_model', str, 'ERROR')
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
return {
"id": str(uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": running_model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
},
"choices": [
{
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason,
"index": 0
}
]
}
def transform_prompt_to_text(prompt: list):
text = ''
for item in prompt:
@ -82,26 +23,26 @@ def transform_prompt_to_text(prompt: list):
return text.strip('\n')
def handle_blocking_request(json_data: dict):
def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
try:
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except requests.exceptions.ReadTimeout:
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
# print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out'
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False, None, 'Request to backend encountered error'
if r.status_code != 200:
print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
return False, r, f'Backend returned {r.status_code}'
return True, r, None
def generate(json_data: dict):
def generate(json_data: dict, cluster_backend, timeout: int = None):
if json_data.get('stream'):
try:
return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False
else:
return handle_blocking_request(json_data)
return handle_blocking_request(json_data, cluster_backend, timeout=timeout)

View File

@ -1,3 +1,7 @@
import requests
from llm_server import opts
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong>
<ul>
@ -7,4 +11,4 @@ vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="
<li><kbd>max_new_tokens</kbd></li>
<li><kbd>num_beams</kbd> <span style="font-size:9pt">(setting to greater than 1 enables beam search)</span></li>
<li><kbd>ban_eos_token</kbd></li>
</ul>"""
</ul>"""

View File

@ -1,26 +1,51 @@
import concurrent.futures
import requests
import tiktoken
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.logging import create_logger
def tokenize(prompt: str) -> int:
def tokenize(prompt: str, backend_url: str) -> int:
assert backend_url
assert isinstance(backend_url, str)
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
assert isinstance(prompt, str)
logger = create_logger('tokenizer')
# The backend could have died between when the request was
# submitted and now, so let's double check it's still online.
backend_url = cluster_config.validate_backend(backend_url)
tokenizer = tiktoken.get_encoding("cl100k_base")
# First we tokenize it locally to determine if it's worth sending it to the backend.
initial_estimate = len(tokenizer.encode(prompt))
if initial_estimate <= opts.context_size + 200:
# Split the prompt into 2000 character chunks
chunk_size = 2000
chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
# Define a function to send a chunk to the server
def send_chunk(chunk):
try:
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json()
return j['length']
except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
return len(tokenizer.encode(prompt)) + 10
else:
# If the result was greater than our context size, return the estimate.
# We won't be sending it through the backend so it does't need to be accurage.
return initial_estimate
logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
return len(tokenizer.encode(chunk)) + 10
# Use a ThreadPoolExecutor to send all chunks to the server at once
with concurrent.futures.ThreadPoolExecutor() as executor:
future_to_chunk = {executor.submit(send_chunk, chunk): chunk for chunk in chunks}
for future in concurrent.futures.as_completed(future_to_chunk):
chunk = future_to_chunk[future]
try:
data = future.result()
except Exception as exc:
logger.warning('%r generated an exception: %s' % (chunk, exc))
return sum(future.result() for future in future_to_chunk)

View File

@ -1,10 +1,9 @@
import threading
from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
from llm_server.database.database import log_prompt
from llm_server.database.log_to_db import log_to_db
from llm_server.llm.llm_backend import LLMBackend
@ -19,16 +18,8 @@ class VLLMBackend(LLMBackend):
# Failsafe
backend_response = ''
r_url = request.url
def background_task():
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
return jsonify({'results': [{'text': backend_response}]}), 200
@ -38,14 +29,20 @@ class VLLMBackend(LLMBackend):
top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0:
top_k = -1
# TODO: support more params
sampling_params = SamplingParams(
temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self._default_params['stop']),
stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
)
except ValueError as e:
return None, str(e).strip('.')

52
llm_server/logging.py Normal file
View File

@ -0,0 +1,52 @@
import logging
import coloredlogs
from llm_server import opts
class LoggingInfo:
def __init__(self):
self._level = logging.INFO
self._format = opts.LOGGING_FORMAT
@property
def level(self):
return self._level
@level.setter
def level(self, value):
self._level = value
@property
def format(self):
return self._format
@format.setter
def format(self, value):
self._format = value
logging_info = LoggingInfo()
def init_logging():
"""
Set up the parent logger.
:return:
"""
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)
def create_logger(name):
logger = logging.getLogger('llm_server').getChild(name)
logger.setLevel(logging_info.level)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging_info.level)
formatter = logging.Formatter(logging_info.format)
handler.setFormatter(formatter)
logger.addHandler(handler)
coloredlogs.install(logger=logger, level=logging_info.level)
return logger

1
llm_server/messages.py Normal file
View File

@ -0,0 +1 @@
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -1,52 +0,0 @@
import json
from datetime import datetime, timedelta
import requests
from llm_server import opts
def get_power_states():
gpu_num = 0
output = {}
while True:
url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
try:
response = requests.get(url, timeout=10)
if response.status_code != 200:
break
data = json.loads(response.text)
power_state_data = data['data'][0]
power_state = None
for i in range(1, len(power_state_data)):
if power_state_data[i] == 1:
power_state = data['labels'][i]
break
output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
except Exception as e:
print('Failed to fetch Netdata metrics:', e)
return output
gpu_num += 1
return output
def get_gpu_wh(gpu_id: int):
chart_name = f"nvidia_smi.gpu{gpu_id}_power"
now = datetime.now()
one_hour_ago = now - timedelta(hours=1)
num_seconds = int((now - one_hour_ago).total_seconds())
params = {
"chart": chart_name,
"after": int(one_hour_ago.timestamp()),
"before": int(now.timestamp()),
"points": num_seconds,
"group": "second",
"format": "json",
"options": "absolute|jsonwrap"
}
response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
data = json.loads(response.text)
total_power_usage_watts = sum(point[1] for point in data['result']['data'])
# total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
return total_power_usage_kwh

View File

@ -1,12 +1,11 @@
# Read-only global variables
# Uppercase variables are read-only globals.
# Lowercase variables are ones that are set on startup and are never changed.
# TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'ERROR'
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None
context_size = 5555
frontend_api_mode = 'ooba'
max_new_tokens = 500
auth_required = False
log_prompts = False
@ -33,7 +32,15 @@ openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
openai_moderation_workers = 10
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True
background_homepage_cacher = True
openai_moderation_timeout = 5
prioritize_by_size = False
cluster_workers = 0
redis_stream_timeout = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"

View File

@ -1,21 +1,9 @@
import sys
from redis import Redis
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.custom_redis import redis
def server_startup(s):
if not redis.get('daemon_started', bool):
if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
# Cache the initial stats
print('Loading backend stats...')
generate_stats()

View File

@ -1,11 +1,18 @@
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, level: str = 'info'):
http_host = redis.get('http_host', str)
def format_sillytavern_err(msg: str, backend_url: str = None, error_type: str = 'info'):
if backend_url:
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
else:
cluster_backend_hash = 'none'
http_host = redis.get('http_host', dtype=str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
-> {error_type.upper()} <-
{msg}
```
```
BACKEND: {cluster_backend_hash}
```"""

View File

@ -100,4 +100,4 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
j = json.loads(str(data))
return True, j
except Exception as e:
return False, e
return False, e

View File

@ -0,0 +1,15 @@
def estimate_model_size(config: dict):
"""
Estimate the size of a model from its config. No idea if this is correct,
but it allows us to compare models.
:param config:
:return:
"""
vocab_size = config.get('vocab_size')
hidden_size = config.get('hidden_size')
num_hidden_layers = config.get('num_hidden_layers')
intermediate_size = config.get('intermediate_size')
if vocab_size and hidden_size and num_hidden_layers and intermediate_size:
total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3)))
return int(total_params / 1e9)
return 0

View File

@ -3,8 +3,8 @@ from typing import Tuple
import flask
from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server import messages, opts
from llm_server.database.log_to_db import log_to_db
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@ -13,8 +13,11 @@ class OobaRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def handle_request(self):
def handle_request(self, return_ok: bool = True):
assert not self.used
if self.offline:
print('This backend is offline:', messages.BACKEND_OFFLINE)
return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()
if not request_valid:
@ -25,14 +28,19 @@ class OobaRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': prompt}
_, backend_response = self.generate_response(llm_request)
return backend_response
if return_ok:
# Always return 200 so ST displays our error messages
return backend_response[0], 200
else:
# The OpenAI route needs to detect 429 errors.
return backend_response
def handle_ratelimited(self, do_log: bool = True):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return backend_response[0], 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
@ -40,7 +48,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this
response_msg = error_msg
else:
response_msg = format_sillytavern_err(error_msg, error_type)
response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
return jsonify({
'results': [{'text': response_msg}]

View File

@ -5,9 +5,11 @@ from ..server_error import handle_server_error
from ... import opts
openai_bp = Blueprint('openai/v1/', __name__)
openai_model_bp = Blueprint('openai/', __name__)
@openai_bp.before_request
@openai_model_bp.before_request
def before_oai_request():
if not opts.enable_openi_compatible_backend:
return 'The OpenAI-compatible backend is disabled.', 401
@ -15,8 +17,22 @@ def before_oai_request():
@openai_bp.errorhandler(500)
@openai_model_bp.errorhandler(500)
def handle_error(e):
return handle_server_error(e)
"""
Found Codes:
"auth_subrequest_error"
"""
print('OAI returning error:', e)
return jsonify({
"error": {
"message": "Internal server error",
"type": "auth_subrequest_error",
"param": None,
"code": "internal_error"
}
}), 500
from .models import openai_list_models

View File

@ -1,113 +1,175 @@
import json
import threading
import time
import traceback
import ujson
from flask import Response, jsonify, request
from redis import Redis
from . import openai_bp
from ..cache import redis
from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
from ...llm.vllm import tokenize
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
def openai_chat_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
handler = OpenAIRequestHandler(request, request_json_body)
if request_json_body.get('stream'):
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
return return_invalid_model_err(model_name)
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
try:
response = generator(msg_to_backend)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
else:
if not request_json_body.get('stream'):
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500
else:
if not opts.enable_streaming:
return 'Streaming disabled', 403
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
handler.parameters, e = handler.get_parameters()
handler.request_json_body = {
'messages': handler.request_json_body['messages'],
'model': handler.request_json_body['model'],
**handler.parameters
}
if opts.openai_silent_trim:
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
# Need to set the prompt in the JSON body since that's what the inference worker expects.
handler.request_json_body['prompt'] = handler.prompt
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
429,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
try:
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
# Need to do this before we enter generate() since we want to be able to
# return a 408 if necessary.
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None # set to null so that the Finally ignores it.
return 'Request Timeout', 408
def generate():
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
# Not printing error since we can just check the daemon log.
print('OAI streaming encountered error')
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": data['new']
},
"finish_reason": None
}
]
}
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except GeneratorExit:
return
except Exception:
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500

View File

@ -1,38 +1,68 @@
import time
import traceback
from flask import jsonify, make_response, request
import simplejson as json
import ujson
from flask import Response, jsonify, request
from redis import Redis
from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.openai.transform import build_openai_response, generate_oai_string
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
try:
response, status_code = OobaRequestHandler(request).handle_request()
if status_code != 200:
return status_code
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
return return_invalid_model_err(model_name)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
handler.request_json_body['prompt'] = handler.prompt
if not request_json_body.get('stream'):
invalid_oai_err_msg = validate_oai(request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
response, status_code = handler.handle_request(return_ok=False)
if status_code == 429:
return handler.handle_ratelimited()
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
running_model = redis.get('running_model', str, 'ERROR')
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
@ -42,7 +72,7 @@ def openai_completions():
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": None
"finish_reason": "stop"
}
],
"usage": {
@ -50,12 +80,141 @@ def openai_completions():
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
})
stats = redis.get('proxy_stats', dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
except Exception:
traceback.print_exc()
return 'Internal Server Error', 500
# TODO:
# stats = redis.get('proxy_stats', dtype=dict)
# if stats:
# response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response, 200
else:
if not opts.enable_streaming:
return 'Streaming disabled', 403
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
handler.parameters, _ = handler.get_parameters()
handler.request_json_body = {
'prompt': handler.request_json_body['prompt'],
'model': handler.request_json_body['model'],
**handler.parameters
}
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
429,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
try:
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None
return 'Request Timeout', 408
def generate():
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
print('OAI streaming encountered error')
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": data['new']
},
"finish_reason": None
}
]
}
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except GeneratorExit:
# This should be triggered if a client disconnects early.
return
except Exception:
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500

View File

@ -1,7 +1,7 @@
from flask import Response
from . import openai_bp
from ..cache import flask_cache
from llm_server.custom_redis import flask_cache
from ... import opts

View File

@ -3,59 +3,58 @@ import traceback
import requests
from flask import jsonify
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
model, error = get_running_model()
if not model:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not model_name:
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
oai = fetch_openai_models()
r = []
r = {
"object": "list",
"data": oai
}
# TODO: verify this works
if opts.openai_expose_our_model:
r = [{
"object": "list",
"data": [
r["data"].insert(0, {
"id": running_model,
"object": "model",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": running_model,
"object": "model",
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": running_model,
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
],
"root": None,
"parent": None
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
]
}]
response = jsonify_pretty(r + oai), 200
],
"root": None,
"parent": None
})
response = jsonify_pretty(r), 200
return response
@ -64,7 +63,14 @@ def fetch_openai_models():
if opts.openai_api_key:
try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
return response.json()['data']
j = response.json()['data']
# The "modelperm" string appears to be user-specific, so we'll
# randomize it just to be safe.
for model in range(len(j)):
for p in range(len(j[model]['permission'])):
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
return j
except:
traceback.print_exc()
return []

View File

@ -1,7 +1,7 @@
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time
@ -17,7 +17,7 @@ def openai_organizations():
"id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()),
"title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx",
"name": f"user-{generate_oai_string(24)}",
"description": "Personal org for bobjoe@0.0.0.0",
"personal": True,
"is_default": True,

View File

@ -1,14 +1,21 @@
import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
from flask import jsonify
from flask import Response, jsonify, make_response
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
@ -20,20 +27,37 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print('OAI Offline:', msg)
return self.handle_error(msg)
if opts.openai_silent_trim:
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
request_valid, invalid_response = self.validate_request()
if not request_valid:
return invalid_response
if opts.openai_api_key and is_api_key_moderated(self.token):
if not self.prompt:
# TODO: format this as an openai error message
return Response('Invalid prompt'), 400
# TODO: support Ooba backend
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
invalid_oai_err_msg = validate_oai(self.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
try:
# Gather the last message from the user and all preceeding system messages
# Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy()
msg_l.reverse()
tag = uuid4()
@ -49,33 +73,40 @@ class OpenAIRequestHandler(RequestHandler):
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e:
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
# Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
if opts.openai_force_no_hashes:
self.parameters['stop'].append('### ')
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
self.request_json_body['top_p'] = 0.01
traceback.print_exc()
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
model = self.request_json_body.get('model')
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else:
return backend_response, backend_response_status_code
def handle_ratelimited(self, do_log: bool = True):
# TODO: return a simulated OpenAI error message
# Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
return 'Ratelimited', 429
model_choices, default_model = get_model_choices()
default_model_info = model_choices[default_model]
w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2
response = jsonify({
"error": {
"message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
"type": "rate_limit_exceeded",
"param": None,
"code": None
}
})
response.headers['x-ratelimit-limit-requests'] = '2'
response.headers['x-ratelimit-remaining-requests'] = '0'
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
if do_log:
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
# TODO: return a simulated OpenAI error message
print('OAI Error:', error_msg)
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
@ -84,3 +115,51 @@ class OpenAIRequestHandler(RequestHandler):
"code": None
}
}), 400
def build_openai_response(self, prompt, response, model=None):
# Seperate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
prompt_tokens = get_token_count(prompt, self.backend_url)
response_tokens = get_token_count(response, self.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
return response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
self.parameters, parameters_invalid_msg = self.get_parameters()
if not self.parameters:
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
return False, (Response('Invalid request, check your parameters and try again.'), 400)
invalid_oai_err_msg = validate_oai(self.parameters)
if invalid_oai_err_msg:
return False, invalid_oai_err_msg
# self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# If the parameters were invalid, let the superclass deal with it.
return super().validate_request(prompt, do_log)

View File

@ -1,12 +1,15 @@
import json
import pickle
import time
from typing import Tuple
from uuid import uuid4
import ujson as json
from redis import Redis
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
def increment_ip_count(client_ip: str, redis_key):
@ -20,24 +23,30 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue:
def __init__(self):
self.redis = Redis(host='localhost', port=6379, db=15)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events')
"""
A queue for a specific backend.
"""
def put(self, item, priority):
event = DataEvent()
def __init__(self, name, db: int = 12):
self.name = name
self.redis = RedisCustom(name, db=db)
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
# TODO: remove this when we're sure nothing strange is happening
assert item is not None
assert priority is not None
assert selected_model is not None
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count:
ip_count = int(ip_count)
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
ip_count = self.get_ip_request_count(item[1])
_, simultaneous_ip = get_token_ratelimit(item[2])
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
return None # reject the request
self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
timestamp = time.time()
event = DataEvent()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
return event
def get(self):
@ -45,31 +54,59 @@ class RedisPriorityQueue:
data = self.redis.zpopmin('queue')
if data:
item = json.loads(data[0][0])
client_ip = item[0][1]
self.decrement_ip_count(client_ip, 'queued_ip_count')
return item
time.sleep(0.1) # wait for something to be added to the queue
def increment_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(redis_key, client_ip, 1)
def decrement_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, -1)
if new_count <= 0:
self.redis.hdel(redis_key, client_ip)
def __len__(self):
return self.redis.zcard('queue')
def get_queued_ip_count(self, client_ip: str):
q = self.redis.hget('queued_ip_count', client_ip)
if not q:
return 0
return 0
def get_ip_request_count(self, client_ip: str):
"""
Get the number of requests in the queue from a specific IP.
This is a bit inefficient since we iterate over the entire queue, but
keeps the queue as a single point of truth instead of tracking a separate hashed
set which can get confusing.
If we run into slowdowns in the future, we should go back to the hashed set approach.
:param client_ip:
:return:
"""
start_time = time.time()
items = self.redis.zrange('queue', 0, -1)
count = 0
for item in items:
item_data = json.loads(item)
if item_data[0][1] == client_ip:
count += 1
elapsed_time = time.time() - start_time
if elapsed_time > 0.5:
raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
return count
def flush(self):
self.redis.flush()
def items(self):
return self.redis.zrange('queue', 0, -1)
def cleanup(self):
now = time.time()
for item in self.items():
item_data = json.loads(item)
timestamp = item_data[-2]
if now - timestamp > opts.backend_generate_request_timeout:
self.redis.zrem('queue', 0, item)
event_id = item_data[1]
event = DataEvent(event_id)
event.set((False, None, 'closed'))
print('Removed timed-out item from queue:', event_id)
class DataEvent:
def __init__(self, event_id=None):
"""
Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol).
"""
def __init__(self, event_id: str = None):
self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub()
@ -84,15 +121,89 @@ class DataEvent:
return pickle.loads(item['data'])
priority_queue = RedisPriorityQueue()
def update_active_workers(key: str, operation: str):
if operation == 'incr':
redis.incr(f'active_gen_workers:{key}')
elif operation == 'decr':
redis.decr(f'active_gen_workers:{key}')
if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
redis.set(f'active_gen_workers:{key}', 0)
def incr_active_workers():
redis.incr('active_gen_workers')
def incr_active_workers(selected_model: str, backend_url: str):
update_active_workers(selected_model, 'incr')
update_active_workers(backend_url, 'incr')
def decr_active_workers():
redis.decr('active_gen_workers')
new_count = redis.get('active_gen_workers', int, 0)
if new_count < 0:
redis.set('active_gen_workers', 0)
def decr_active_workers(selected_model: str, backend_url: str):
update_active_workers(selected_model, 'decr')
update_active_workers(backend_url, 'decr')
class PriorityQueue:
"""
Helper class to wrangler all the different queues.
"""
def __init__(self, backends: set = None):
"""
Only have to load the backends once.
:param backends:
"""
self.redis = Redis(host='localhost', port=6379, db=9)
if backends:
for item in backends:
self.redis.sadd('backends', item)
def get_backends(self):
return {x.decode('utf-8') for x in self.redis.smembers('backends')}
def get_queued_ip_count(self, client_ip: str):
count = 0
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
count += queue.get_ip_request_count(client_ip)
return count
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model, do_stream)
def activity(self):
lines = []
status_redis = RedisCustom('worker_status')
for worker in status_redis.keys():
lines.append((worker, status_redis.getp(worker)))
return sorted(lines)
def len(self, model_name):
count = 0
backends_with_models = set()
for k in self.get_backends():
info = cluster_config.get_backend(k)
if info.get('model') == model_name:
backends_with_models.add(k)
for backend_url in backends_with_models:
count += len(RedisPriorityQueue(backend_url))
return count
def __len__(self):
count = 0
p = set()
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
p.add((backend_url, len(queue)))
count += len(queue)
return count
def flush(self):
for k in self.redis.keys():
q = json.loads(self.redis.get(k))
q.flush()
self.redis.set(k, json.dumps(q))
def flush_db(self):
self.redis.flushdb()
priority_queue = PriorityQueue()

View File

@ -5,23 +5,22 @@ import flask
from flask import Response, request
from llm_server import opts
from llm_server.database.conn import database
from llm_server.database.database import log_prompt
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit
from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# Routes need to validate it, here we just load it
if incoming_json:
@ -34,11 +33,38 @@ class RequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
self.backend = get_backend()
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.parameters = None
self.used = False
redis.zadd('recent_prompters', {self.client_ip: time.time()})
# This is null by default since most handlers need to transform the prompt in a specific way.
self.prompt = None
self.selected_model = selected_model
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
# Debug stuff
# if not self.cluster_backend_info.get('mode'):
# print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model'):
# print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model_config'):
# print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
self.offline = True
else:
self.offline = False
self.selected_model = self.cluster_backend_info['model']
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
if self.token and not self.token.startswith('SYSTEM__'):
# "recent_prompters" is only used for stats.
redis.zadd('recent_prompters', {self.client_ip: time.time()})
def check_online(self) -> bool:
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
return self.cluster_backend_info['online']
def get_auth_token(self):
if self.request_json_body.get('X-API-KEY'):
@ -49,6 +75,8 @@ class RequestHandler:
return parse_token(self.request.headers['Authorization'])
def get_client_ip(self):
if self.request.headers.get('Llm-Connecting-Ip'):
return self.request.headers['Llm-Connecting-Ip']
if self.request.headers.get('X-Connecting-IP'):
return self.request.headers.get('X-Connecting-IP')
elif self.request.headers.get('Cf-Connecting-Ip'):
@ -58,26 +86,7 @@ class RequestHandler:
else:
return self.request.remote_addr
def get_token_ratelimit(self):
priority = DEFAULT_PRIORITY
simultaneous_ip = opts.simultaneous_requests_per_ip
if self.token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip
def get_parameters(self):
if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg
@ -119,7 +128,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
return False, backend_response
return True, (None, 0)
@ -131,14 +140,18 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid:
return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
else:
event = None
if not event:
return (False, None, None, 0), self.handle_ratelimited()
# TODO: add wait timeout
success, response, error_msg = event.wait()
if error_msg == 'closed':
return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
end_time = time.time()
elapsed_time = end_time - self.start_time
@ -160,7 +173,17 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_to_db(ip=self.client_ip,
token=self.token,
prompt=prompt,
response=backend_response[0].data.decode('utf-8'),
gen_time=None,
parameters=self.parameters,
headers=dict(self.request.headers),
backend_response_code=response_status_code,
request_url=self.request.url,
backend_url=self.backend_url,
is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -180,7 +203,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -189,22 +212,29 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool:
if self.token_priority == 0:
return False
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
x = redis.hget('processing_ips', self.client_ip)
if x:
processing_ip = int(x)
else:
processing_ip = 0
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
return False
else:
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
return True
else:
return False
def handle_request(self) -> Tuple[flask.Response, int]:
# Must include this in your child.
# if self.used:
# raise Exception('Can only use a RequestHandler object once.')
# assert not self.used
# if self.offline:
# msg = f'{self.selected_model} is not a valid model choice.'
# print(msg)
# return format_sillytavern_err(msg)
raise NotImplementedError
def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
@ -214,11 +244,11 @@ class RequestHandler:
raise NotImplementedError
def get_backend():
if opts.mode == 'oobabooga':
return OobaboogaBackend()
elif opts.mode == 'vllm':
return VLLMBackend()
def get_backend_handler(mode, backend_url: str):
if mode == 'oobabooga':
return OobaboogaBackend(backend_url)
elif mode == 'vllm':
return VLLMBackend(backend_url)
else:
raise Exception

View File

@ -1,3 +1,3 @@
def handle_server_error(e):
print(e)
return {'error': True}, 500
print('Internal Error:', e)
return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@ -1,33 +1,11 @@
from datetime import datetime
from llm_server.routes.cache import redis
# proompters_5_min = 0
# concurrent_semaphore = Semaphore(concurrent_gens)
from llm_server.custom_redis import redis
from llm_server.helpers import round_up_base
server_start_time = datetime.now()
# TODO: do I need this?
# def elapsed_times_cleanup():
# global wait_in_queue_elapsed
# while True:
# current_time = time.time()
# with wait_in_queue_elapsed_lock:
# global wait_in_queue_elapsed
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
# time.sleep(1)
def calculate_avg_gen_time():
# Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time')
if average_generation_time is None:
return 0
else:
return float(average_generation_time)
def get_total_proompts():
count = redis.get('proompts')
if count is None:
@ -37,10 +15,27 @@ def get_total_proompts():
return count
def get_active_gen_workers():
active_gen_workers = redis.get('active_gen_workers')
if active_gen_workers is None:
count = 0
def get_active_gen_workers_model(selected_model: str = None):
return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0)
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
count = int(active_gen_workers)
return count
# No queue
return gen_time_calc

View File

@ -3,18 +3,18 @@ import traceback
from flask import jsonify, request
from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
@bp.route('/generate', methods=['POST'])
def generate():
@bp.route('/v1/generate', methods=['POST'])
@bp.route('/<model_name>/v1/generate', methods=['POST'])
def generate(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
handler = OobaRequestHandler(request)
handler = OobaRequestHandler(request, selected_model=model_name)
try:
return handler.handle_request()
except Exception:

View File

@ -2,83 +2,30 @@ import time
from datetime import datetime
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states
from llm_server.routes.cache import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
from llm_server.helpers import deep_sort
from llm_server.routes.stats import get_total_proompts, server_start_time
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
# No queue
return gen_time_calc
# TODO: have routes/__init__.py point to the latest API version generate_stats()
def generate_stats(regen: bool = False):
if not regen:
c = redis.get('proxy_stats', dict)
c = redis.getp('proxy_stats')
if c:
return c
model_name, error = get_running_model() # will return False when the fetch fails
if isinstance(model_name, bool):
online = False
else:
online = True
redis.set('running_model', model_name)
model_choices, default_model = get_model_choices(regen=True)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
# if len(t) == 0:
# estimated_wait = 0
# else:
# waits = [elapsed for end, elapsed in t]
# estimated_wait = int(sum(waits) / len(waits))
active_gen_workers = get_active_gen_workers()
proompters_in_queue = len(priority_queue)
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
if opts.netdata_root:
netdata_stats = {}
power_states = get_power_states()
for gpu, power_state in power_states.items():
netdata_stats[gpu] = {
'power_state': power_state,
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
}
else:
netdata_stats = {}
base_client_api = redis.get('base_client_api', str)
base_client_api = redis.get('base_client_api', dtype=str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
output = {
'models': {
'choices': model_choices,
'default': default_model,
},
'stats': {
'proompters': {
'5_min': proompters_5_min,
@ -86,39 +33,49 @@ def generate_stats(regen: bool = False):
},
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(average_generation_time),
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
},
'online': online,
'endpoints': {
'blocking': f'https://{base_client_api}',
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
},
'queue': {
'processing': active_gen_workers,
'queued': proompters_in_queue,
'estimated_wait_sec': int(estimated_wait_sec),
},
'timestamp': int(time.time()),
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'concurrent': opts.concurrent_gens,
'model': opts.manual_model_name if opts.manual_model_name else model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
'api_mode': opts.frontend_api_mode
},
'keys': {
'openaiKeys': '',
'anthropicKeys': '',
},
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
'nvidia': netdata_stats
'backends': {},
'online': len(model_choices) > 0
}
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
if opts.show_backends:
for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']:
continue
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
output['backends'][backend_info['hash']] = {
'uptime': backend_uptime,
'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),
'model': backend_info['model'],
'mode': backend_info['mode'],
'nvidia': backend_info['nvidia'],
'priority': backend_info['priority'],
}
result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then.
if base_client_api:
redis.set_dict('proxy_stats', result) # Cache with no expiry
redis.setp('proxy_stats', result)
return result

View File

@ -1,186 +1,200 @@
import json
import threading
import time
import traceback
from typing import Union
import ujson
from flask import request
from redis import Redis
from ..cache import redis
from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ..queue import priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...sock import sock
# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
# We solve this by splitting the routes
@sock.route('/api/v1/stream')
def stream(ws):
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_in_bg(quitting_err_msg, is_error=True)
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
return 'This is a websocket endpoint.', 400
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
def background_task_exception():
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
do_stream(ws, model_name=None)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
if not opts.enable_streaming:
return 'Streaming is disabled', 401
@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
do_stream(ws, model_name)
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
def do_stream(ws, model_name):
event_id = None
try:
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None,
is_error=True
)
handler = OobaRequestHandler(request, request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
if not opts.enable_streaming:
return 'Streaming disabled', 403
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
if err_msg:
send_err_and_quit(err_msg)
return
# We have to do auth ourselves since the details are sent in the message.
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
llm_request = {
**handler.parameters,
'prompt': input_prompt,
'stream': True,
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
if handler.offline:
msg = f'{handler.selected_model} is not a valid model choice.'
print(msg)
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'message_num': 0,
'text': msg
}))
return
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
else:
# Be extra careful when getting attributes from the response object
try:
response_status_code = response.status_code
except:
response_status_code = 0
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
if err_msg:
send_err_and_quit(err_msg)
return
partial_response = b''
handler.parameters, _ = handler.get_parameters()
handler.prompt = input_prompt
handler.request_json_body = {
'prompt': handler.prompt,
**handler.parameters
}
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
# The has client closed the stream.
if request:
request.close()
ws.close()
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
return
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
r = handler.handle_ratelimited()
send_err_and_quit(r[0].data)
return
event_id = event.event_id
_, stream_name, error_msg = event.wait()
if error_msg:
print('Stream failed to start streaming:', error_msg)
ws.close(reason=1014, message='Request Timeout')
return
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0' # The ID of the last entry we read.
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
return
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
data = ujson.loads(item[b'data'])
if data['error']:
print(data['error'])
send_err_and_quit('Encountered exception while streaming.')
return
elif data['new']:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': data['new']
}))
message_num += 1
partial_response = b'' # Reset the partial response
# If there is no more data, break the loop
if not chunk:
break
end_time = time.time()
elapsed_time = end_time - start_time
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
except:
traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': generated_text
}))
if request:
request.close()
ws.close()
log_in_bg(generated_text, is_error=True, status_code=response_status_code)
return
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers()
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
except:
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
ws.close() # this is important if we encountered and error and exited early.
generated_text = generated_text + data['new']
elif data['completed']:
return
except:
send_err_and_quit('Encountered exception while streaming.')
traceback.print_exc()
finally:
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
except:
# The client closed the stream.
pass
if stream_name:
stream_redis.delete(stream_name)
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url
)
finally:
if event_id:
redis.publish(f'notifications:{event_id}', 'canceled')
try:
# Must close the connection or greenlets will complain.
ws.close()
except:
pass

View File

@ -2,22 +2,16 @@ import time
from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp
from ..auth import requires_auth
from ..cache import flask_cache
from ... import opts
from ...llm.info import get_running_model
from ...cluster.backend import get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
# @bp.route('/info', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_info():
# # requests.get()
# return 'yes'
@bp.route('/model', methods=['GET'])
def get_model():
@bp.route('/v1/model', methods=['GET'])
@bp.route('/<model_name>/v1/model', methods=['GET'])
def get_model(model_name=None):
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -26,24 +20,21 @@ def get_model():
if cached_response:
return cached_response
model_name, error = get_running_model()
if not model_name:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not is_valid_model(model_name):
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
'code': 400,
'msg': 'Model does not exist.',
}), 400
else:
num_backends = len(get_backends_from_model(model_name))
response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'model_backend_count': num_backends,
'timestamp': int(time.time())
}), 200
flask_cache.set(cache_key, response, timeout=60)
return response
@bp.route('/backend', methods=['GET'])
@requires_auth
def get_backend():
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200

View File

@ -1,8 +1,10 @@
from flask import jsonify
from llm_server.custom_redis import flask_cache
from . import bp
from .generate_stats import generate_stats
from ..cache import flask_cache
from ..auth import requires_auth
from ...cluster.cluster_config import cluster_config, get_backends
from ...helpers import jsonify_pretty
@ -10,3 +12,14 @@ from ...helpers import jsonify_pretty
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
return jsonify_pretty(generate_stats())
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
online, offline = get_backends()
result = {}
for i in online + offline:
info = cluster_config.get_backend(i)
result[info['hash']] = info
return jsonify(result), 200

View File

@ -3,6 +3,6 @@ from flask_sock import Sock
sock = Sock()
def init_socketio(app):
def init_wssocket(app):
global sock
sock.init_app(app)

View File

@ -1,35 +0,0 @@
from threading import Thread
from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')

View File

@ -1,51 +0,0 @@
import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
need_to_wait()
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers()
if not request_json_body:
# This was a dummy request from the websocket handler.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
continue
try:
success, response, error_msg = generator(request_json_body)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers()
def start_workers(num_workers: int):
i = 0
for _ in range(num_workers):
t = threading.Thread(target=worker)
t.daemon = True
t.start()
i += 1
print(f'Started {i} inference workers.')
def need_to_wait():
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', int, 0)
s = time.time()
while active_workers >= opts.concurrent_gens:
time.sleep(0.01)
e = time.time()
if e - s > 0.5:
print(f'Worker was delayed {e - s} seconds.')

View File

@ -0,0 +1,32 @@
import time
from redis import Redis
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
# NOT NEEDED
def cleaner():
r = Redis(db=8)
stream_info = {}
while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
processed_streams = []
for stream in all_streams:
stream = stream.decode()
current_size = r.xlen(stream)
# If the stream is new or its size has changed, update the size and time in the dictionary
if stream not in stream_info or current_size != stream_info[stream]['size']:
stream_info[stream] = {'size': current_size, 'time': time.time()}
processed_streams.append(stream)
else:
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]
time.sleep(60)

View File

@ -0,0 +1,160 @@
import json
import threading
import time
import traceback
from uuid import uuid4
import ujson
from redis import Redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.logging import create_logger
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
stream_redis = Redis(db=8)
STREAM_NAME_PREFIX = 'stream'
def check_cancellation(event, event_id):
"""
This thread checks the pub/sub channel in the background so the main process
isn't bogged down with Redis calls. Otherwise, the main process slows down to 1 token/sec.
:param event:
:param event_id:
:return:
"""
pubsub = redis.pubsub()
pubsub.subscribe(f'notifications:{event_id}')
while not event.is_set():
message = pubsub.get_message()
if message and message['data'] == b'canceled':
event.set()
time.sleep(0.5) # check every half second
def get_stream_name(name: str):
return f'{STREAM_NAME_PREFIX}:{name}'
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
logger = create_logger('inferencer')
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
event = threading.Event()
threading.Thread(target=check_cancellation, args=(event, event_id)).start()
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
# If there is no more data, break the loop
if not chunk:
break
if event.is_set():
logger.debug('Client canceled generation')
response.close()
return
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
except AttributeError as e:
if str(e) == "'bool' object has no attribute 'iter_content'":
# We don't care about these errors.
logger.debug('failed to stream from backend - no response')
else:
raise
except Exception as e:
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
raise # We won't handle the exception here.
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
event.set() # stop the cancellation checking thread
#
def worker(backend_url):
logger = create_logger('inferencer')
status_redis = RedisCustom('worker_status')
worker_id = str(uuid4())
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
status_redis.setp(str(worker_id), 'waiting...')
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
event = DataEvent(event_id)
try:
backend_info = cluster_config.get_backend(backend_url)
except:
# This is not a critical error because it usually means that the backend is
# offline and this backend is in a state of transition from online to offline.
logger.debug(f'got an exception while getting info for backend {backend_url} - ', traceback.format_exc())
event.set((False, None, 'exception'))
continue
if not backend_info['online']:
event.set((False, None, 'canceled'))
continue
if not selected_model:
selected_model = backend_info['model']
logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}. Streaming: {do_stream}")
try:
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)
if do_stream:
status_redis.setp(str(worker_id), ('streaming', client_ip))
# Return the name of the stream that the slave should connect to.
event.set((True, get_stream_name(worker_id), None))
msg_to_backend = {
**parameters,
'prompt': request_json_body['prompt'],
'stream': True,
}
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
else:
# Normal inference (not streaming).
status_redis.setp(str(worker_id), ('generating', client_ip))
success, response, error_msg = generator(request_json_body, backend_url)
event.set((success, response, error_msg))
except:
logger.error(traceback.format_exc())
event.set((False, None, 'exception'))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None)
def start_workers(cluster: dict):
logger = create_logger('inferencer')
i = 0
for item in cluster:
for _ in range(item['concurrent_gens']):
t = threading.Thread(target=worker, args=(item['backend_url'],))
t.daemon = True
t.start()
i += 1
logger.info(f'Started {i} inference workers.')

View File

@ -0,0 +1,31 @@
import pickle
import traceback
import redis
from llm_server.database.database import do_db_log
def db_logger():
"""
We don't want the logging operation to be blocking, so we will use a background worker
to do the logging.
:return:
"""
r = redis.Redis(host='localhost', port=6379, db=3)
p = r.pubsub()
p.subscribe('database-logger')
for message in p.listen():
try:
if message['type'] == 'message':
data = pickle.loads(message['data'])
function_name = data['function']
args = data['args']
kwargs = data['kwargs']
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
except:
traceback.print_exc()

View File

@ -1,56 +0,0 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
def main_background_thread():
redis.set('average_generation_elapsed_sec', 0)
redis.set('estimated_avg_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
while True:
# TODO: unify this
if opts.mode == 'oobabooga':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
elif opts.mode == 'vllm':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
else:
raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
redis.set('estimated_avg_tps', estimated_avg_tps)
time.sleep(60)

View File

@ -0,0 +1,57 @@
import time
import requests
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_backends
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
def main_background_thread():
while True:
online, offline = get_backends()
for backend_url in online:
backend_info = cluster_config.get_backend(backend_url)
backend_mode = backend_info['mode']
backend_info = get_info(backend_url, backend_mode)
running_model = backend_info.get('model')
if not running_model:
continue
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
if average_generation_elapsed_sec: # returns None on exception
cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
if average_output_tokens:
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
if average_generation_elapsed_sec and average_output_tokens:
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
if opts.background_homepage_cacher:
try:
base_client_api = redis.get('base_client_api', dtype=str)
r = requests.get('https://' + base_client_api, timeout=5)
except Exception as e:
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
backends = priority_queue.get_backends()
for backend_url in backends:
queue = RedisPriorityQueue(backend_url)
queue.cleanup()
time.sleep(30)
def calc_stats_for_backend(backend_url, running_model, backend_mode):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps

View File

@ -1,10 +1,13 @@
import json
import threading
import time
import traceback
import redis as redis_redis
from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger
redis_moderation = redis_redis.Redis()
@ -16,36 +19,43 @@ def start_moderation_workers(num_workers):
t.daemon = True
t.start()
i += 1
print(f'Started {i} moderation workers.')
def moderation_worker():
while True:
result = redis_moderation.blpop('queue:msgs_to_check')
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
print(result)
traceback.print_exc()
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
# TODO: don't use UUID tags to identify items. Use native redis
def get_results(tag, num_tasks):
tag = str(tag) # Required for comparison with Redis results.
tag = str(tag) # Cast a UUID4 to a string.
flagged_categories = set()
num_results = 0
start_time = time.time()
while num_results < num_tasks:
result = redis_moderation.blpop('queue:flagged_categories')
result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
if result is None:
break # Timeout occurred, break the loop.
result_tag, categories = json.loads(result[1])
if result_tag == tag:
if categories:
for item in categories:
flagged_categories.add(item)
num_results += 1
if time.time() - start_time > opts.openai_moderation_timeout:
logger.warning('Timed out waiting for result from moderator')
break
return list(flagged_categories)
def moderation_worker():
logger = create_logger('moderator')
while True:
result = redis_moderation.blpop(['queue:msgs_to_check'])
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
logger.error(traceback.format_exc())
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))

View File

@ -1,25 +1,34 @@
import logging
import time
import traceback
from llm_server.routes.cache import redis
from llm_server.cluster.backend import get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.logging import create_logger
from llm_server.routes.queue import priority_queue
logger = logging.getLogger('console_printer')
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
def console_printer():
logger = create_logger('console_printer')
time.sleep(3)
while True:
processing = redis.hkeys('processing_ips')
processing_count = 0
for ip in processing:
processing_count += int(redis.hget('processing_ips', ip))
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
try:
processing = redis.keys('active_gen_workers:http*') # backends always start with http
processing_count = 0
if len(processing):
for k in processing:
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
activity = priority_queue.activity()
# Calculate the queue size the same way it's done on the stats.
queue_size = 0
running_models = get_running_models()
for model in running_models:
queue_size += priority_queue.len(model)
# Active Workers and Processing should read the same. If not, that's an issue.
logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
except:
logger.error(traceback.format_exc())
time.sleep(10)

View File

@ -1,6 +1,6 @@
import time
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def recent_prompters_thread():

View File

@ -0,0 +1,58 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.cluster.worker import cluster_worker
from llm_server.logging import create_logger
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.logger import db_logger
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers
from llm_server.workers.printer import console_printer
from llm_server.workers.recenter import recent_prompters_thread
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)
def start_background():
logger = create_logger('threader')
start_workers(opts.cluster)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
logger.info('Started the main background thread.')
num_moderators = opts.cluster_workers * 3
start_moderation_workers(num_moderators)
logger.info(f'Started {num_moderators} moderation workers.')
t = Thread(target=cache_stats)
t.daemon = True
t.start()
logger.info('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
logger.info('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
logger.info('Started the console logger.infoer.')
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
logger.info('Started the cluster worker.')
t = Thread(target=db_logger)
t.daemon = True
t.start()
logger.info('Started background logger.')

View File

@ -1,9 +0,0 @@
import time
from llm_server.routes.v1.generate_stats import generate_stats
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)

103
other/gradio/gradio_chat.py Normal file
View File

@ -0,0 +1,103 @@
import os
import sys
import time
import traceback
import warnings
from threading import Thread
import gradio as gr
import openai
import requests
warnings.filterwarnings("ignore")
API_BASE = os.getenv('API_BASE')
if not API_BASE:
print('Must set the secret variable API_BASE to your https://your-site/api')
sys.exit(1)
API_BASE = API_BASE.strip('/')
APP_TITLE = os.getenv('APP_TITLE')
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
TRACKING_CODE = os.getenv('TRACKING_CODE')
def background():
while True:
previous = openai.api_base
try:
r = requests.get(API_BASE + '/stats').json()
if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
else:
openai.api_base = API_BASE + '/openai/v1'
except:
traceback.print_exc()
openai.api_base = API_BASE + '/openai/v1'
if openai.api_base != previous:
print('Set primary model to', openai.api_base)
time.sleep(10)
if PRIMARY_MODEL_CHOICE:
t = Thread(target=background)
t.daemon = True
t.start()
print('Started the background thread.')
# A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
# the content in CONTEXT_TRIGGER_INJECTION will be injected.
# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples.
CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
openai.api_key = 'null'
openai.api_base = API_BASE + '/openai/v1'
def stream_response(prompt, history):
messages = []
do_injection = False
for human, assistant in history:
messages.append({'role': 'user', 'content': str(human)})
messages.append({'role': 'assistant', 'content': str(assistant)})
if CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in human:
do_injection = True
messages.append({'role': 'user', 'content': prompt})
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
try:
response = openai.ChatCompletion.create(
model='0',
messages=messages,
temperature=0,
max_tokens=300,
stream=True,
headers={'LLM-Source': 'huggingface-demo'}
)
except Exception:
raise gr.Error("Failed to reach inference endpoint.")
message = ''
for chunk in response:
if len(chunk['choices'][0]['delta']) != 0:
message += chunk['choices'][0]['delta']['content']
yield message
examples = ["hello"]
if CONTEXT_TRIGGER_PHRASE:
examples.insert(0, CONTEXT_TRIGGER_PHRASE)
with gr.Blocks(analytics_enabled=False) as demo:
gr.ChatInterface(stream_response, examples=examples, title=APP_TITLE, analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}')
if TRACKING_CODE:
print('Inserting tracking code')
gr.HTML(TRACKING_CODE)
demo.queue(concurrency_count=1, api_open=False).launch(show_api=False)

View File

@ -0,0 +1,3 @@
gradio
openai
requests

View File

@ -1,3 +1,8 @@
"""
This file is used to run certain tasks when the HTTP server starts.
It's located here so it doesn't get imported with daemon.py
"""
try:
import gevent.monkey

38
other/nginx-site.conf Normal file
View File

@ -0,0 +1,38 @@
server
{
listen 443 ssl http2 default_server;
server_name _;
proxy_set_header Host $host;
proxy_set_header Connection $http_connection;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Scheme $scheme;
location ~* ^/api/(.*?|v1|openai)/(v1|(generate|stream)|(chat/completions|completions))$
{
# Route to inference endpoints
proxy_pass http://127.0.0.1:5000;
# Required for streaming (both websockets and SSE).
proxy_buffering off;
proxy_cache off;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
# Set long timeouts for inference operations.
# Cloudflare has a timeout of 100 seconds.
proxy_read_timeout 120;
proxy_connect_timeout 120;
proxy_send_timeout 120;
}
location /
{
proxy_pass http://127.0.0.1:5000;
}
ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt;
ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key;
include /etc/nginx/snippets/ssl-params.conf;
}

11
other/tests/config.sh Normal file
View File

@ -0,0 +1,11 @@
HOST="proxy.chub-archive.evulid.cc"
AUTH_KEY="TEST_1df979f0-6df1-41bd-814a-e99b1680e727"
PROXY_SERVERS=(
"http://172.0.4.7:3128"
"http://172.0.4.8:3128"
"http://172.0.4.10:3128"
"http://172.0.4.12:3128"
"http://172.0.4.13:3128"
)

58
other/tests/generate.sh Executable file
View File

@ -0,0 +1,58 @@
#!/bin/bash
SLEEP_TIME=2
while getopts p:t: flag; do
case "${flag}" in
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ -n "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"prompt": "Write a 300 word story about an apple tree.",
"temperature": 1,
"max_new_tokens": 100,
"top_p": 1.0,
"top_k": -1,
"use_beam_search": false,
"stop": ["TEST"],
"ignore_eos": false,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"length_penalty": 1.0,
"early_stopping": false
}
EOF
)
curl "https://$HOST/api/v1/generate" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

View File

@ -0,0 +1,52 @@
#!/bin/bash
DO_STREAM=false
SLEEP_TIME=2
while getopts p:t:s flag; do
case "${flag}" in
s) DO_STREAM=true ;;
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ ! -z "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Write a 300 word story about an apple tree."}],
"max_tokens": 100,
"stream": $DO_STREAM
}
EOF
)
curl "https://$HOST/api/openai/v1/chat/completions" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

52
other/tests/oai-completion.sh Executable file
View File

@ -0,0 +1,52 @@
#!/bin/bash
DO_STREAM=false
SLEEP_TIME=2
while getopts p:t:s flag; do
case "${flag}" in
s) DO_STREAM=true ;;
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ ! -z "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"model": "gpt-4",
"prompt": "Write a 300 word story about an apple tree.",
"max_tokens": 100,
"stream": $DO_STREAM
}
EOF
)
curl "https://$HOST/api/openai/v1/completions" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

64
other/tests/start-bulk.sh Executable file
View File

@ -0,0 +1,64 @@
#!/bin/bash
# Function to display help message
function display_help {
echo "Usage: $0 -n num_windows -c command"
echo
echo " -n, --number Number of windows to create"
echo " -c, --command Command to run in each window"
echo
exit 1
}
# Parse command line arguments
while getopts "n:c:h" opt; do
case ${opt} in
n)
num_windows=${OPTARG}
;;
c)
command=${OPTARG}
;;
h)
display_help
;;
\?)
echo "Invalid option: -$OPTARG" 1>&2
display_help
;;
:)
echo "Option -$OPTARG requires an argument." 1>&2
display_help
;;
esac
done
# Check if number of windows and command are provided
if [ -z "$num_windows" ] || [ -z "$command" ]; then
echo "Both number of windows and command are required."
display_help
fi
# Calculate rows and columns
rows=$(echo "sqrt($num_windows)" | bc)
columns=$(echo "($num_windows + $rows - 1) / $rows" | bc)
# Create a new tmux session
tmux new-session -d -s llm_tester "$command -p 0"
# Create the remaining windows
for ((i = 1; i < $num_windows; i++)); do
if ((i % $columns == 0)); then
tmux select-layout -t llm_tester:0 tiled
tmux select-pane -t 0
tmux split-window -t llm_tester:0 -v "$command -p $i"
else
tmux split-window -t llm_tester:0 -h "$command -p $i"
fi
done
# Balance the windows
tmux select-layout -t llm_tester:0 tiled
# Attach to the session
tmux attach-session -t llm_tester

62
other/ooba-test-streaming.py → other/tests/stream.py Normal file → Executable file
View File

@ -1,37 +1,50 @@
import asyncio
import json
import sys
import os
import time
from pathlib import Path
try:
import websockets
except ImportError:
print("Websockets package not found. Make sure it's installed.")
# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5000'
URI = f'ws://{HOST}/api/v1/stream'
script_path = os.path.dirname(os.path.realpath(__file__))
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
def parse_bash_config(file_path):
config = {}
with open(file_path, 'r') as f:
for line in f:
if line.startswith('#') or '=' not in line:
continue
key, value = line.strip().split('=', 1)
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
elif value.startswith('(') and value.endswith(')'):
value = value[1:-1].split()
value = [v.strip('"') for v in value]
config[key] = value
return config
config = parse_bash_config(Path(script_path, 'config.sh'))
async def run(context):
# Note: the selected defaults change from time to time.
request = {
'prompt': context,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
# Generation params. If 'preset' is set to different than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
@ -48,7 +61,6 @@ async def run(context):
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
@ -58,7 +70,7 @@ async def run(context):
'stopping_strings': []
}
async with websockets.connect(URI, ping_interval=None) as websocket:
async with websockets.connect(f'wss://{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket:
await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply
@ -67,20 +79,28 @@ async def run(context):
incoming_data = await websocket.recv()
incoming_data = json.loads(incoming_data)
print(incoming_data)
match incoming_data['event']:
case 'text_stream':
yield incoming_data['text']
# case 'text_stream':
# yield incoming_data['text']
case 'stream_end':
return
async def print_response_stream(prompt):
async for response in run(prompt):
print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
print('\n\nfinished')
try:
async for response in run(prompt):
print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
except Exception as e:
print(e)
if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)"
asyncio.run(print_response_stream(prompt))
prompt = "Write a 300 word story about an apple tree.\n\n"
while True:
print('--> START <--')
asyncio.run(print_response_stream(prompt))
print('--> DONE <--')
time.sleep(2)

View File

@ -1,21 +1,18 @@
flask~=2.3.3
flask_cors
pyyaml~=6.0.1
flask_caching
Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gunicorn
gevent~=23.9.0.post1
async-timeout
flask-sock
uvicorn~=0.23.2
fastapi~=0.103.1
torch~=2.0.1
PyMySQL~=1.1.0
DBUtils~=3.0.3
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
celery[redis]
flask-sock==0.6.0
gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm==0.2.1.post1
gradio~=3.46.1
coloredlogs~=15.0.1

138
server.py
View File

@ -1,5 +1,3 @@
from llm_server.config.config import mode_ui_names
try:
import gevent.monkey
@ -7,37 +5,46 @@ try:
except ImportError:
pass
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config
import os
import sys
from pathlib import Path
import simplejson as json
from flask import Flask, jsonify, render_template, request
from flask import Flask, jsonify, render_template, request, Response
import llm_server
import config
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.pre_fork import server_startup
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.sock import init_wssocket
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
# TODO: implement background thread to test backends via sending test prompts
# TODO: if backend fails request, mark it as down
# TODO: allow setting concurrent gens per-backend
# TODO: set the max tokens to that of the lowest backend
# TODO: implement RRD backend loadbalancer option
# TODO: have VLLM reject a request if it already has n == concurrent_gens running
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: use coloredlogs
# TODO: need to update opts. for workers
# TODO: seperate queue item timeout for websockets (make longer, like 5 minutes)
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
# TODO: redis SCAN vs KEYS??
# TODO: is frequency penalty the same as ooba repetition penalty???
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# Lower priority
# TODO: if a backend is at its limit of concurrent requests, choose a different one
# TODO: make error messages consitient
# TODO: support logit_bias on OpenAI and Ooba endpoints.
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: validate openai_silent_trim works as expected and only when enabled
# TODO: rewrite config storage. Store in redis so we can reload it.
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estiamted wait time lags behind the stats
# TODO: simulate OpenAI error messages regardless of endpoint
@ -59,19 +66,16 @@ except ModuleNotFoundError as e:
print('Please see README.md for install instructions.')
sys.exit(1)
import config
from llm_server import opts
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
# Fixes ConcurrentObjectUseError
# https://github.com/miguelgrinberg/simple-websocket/issues/24
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
init_wssocket(app)
flask_cache.init_app(app)
flask_cache.clear()
@ -82,18 +86,13 @@ if config_path_environ:
else:
config_path = Path(script_path, 'config', 'config.yml')
success, config, msg = load_config(config_path, script_path)
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
llm_server.llm.redis = RedisWrapper('local_llm')
create_db()
# print(app.url_map)
@app.route('/')
@ -101,20 +100,30 @@ create_db()
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
base_client_api = redis.get('base_client_api', dtype=str)
stats = generate_stats()
model_choices, default_model = get_model_choices()
if not stats['online']:
running_model = estimated_wait_sec = 'offline'
else:
running_model = redis.get('running_model', str, 'ERROR')
if default_model:
if not model_choices.get(default_model):
return 'The server is still starting up. Please wait...'
active_gen_workers = get_active_gen_workers()
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
default_model_info = model_choices[default_model]
if default_model_info['queued'] == 0 and default_model_info['queued'] >= default_model_info['concurrent_gens']:
# There will be a wait if the queue is empty but prompts are processing, but we don't
# know how long.
estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
default_estimated_wait_sec = f"less than {int(default_model_info['estimated_wait'])} seconds"
else:
estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"
default_estimated_wait_sec = f"{int(default_model_info['estimated_wait'])} seconds"
else:
default_model_info = {
'model': 'OFFLINE',
'processing': '-',
'queued': '-',
'context_size': '-',
}
default_estimated_wait_sec = 'OFFLINE'
if len(config['analytics_tracking_code']):
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
@ -127,32 +136,47 @@ def home():
info_html = ''
mode_info = ''
if opts.mode == 'vllm':
mode_info = vllm_info
base_client_api = redis.get('base_client_api', str)
for k, v in cluster_config.all().items():
if v['mode'] == 'vllm':
mode_info = vllm_info
break
return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name,
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
default_model=default_model_info['model'],
default_active_gen_workers=default_model_info['processing'],
default_proompters_in_queue=default_model_info['queued'],
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
estimated_wait=estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[opts.frontend_api_mode][0],
api_input_textbox=mode_ui_names[opts.frontend_api_mode][1],
streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2],
default_context_size=default_model_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info,
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
expose_openai_system_prompt=opts.expose_openai_system_prompt,
enable_streaming=opts.enable_streaming,
model_choices=model_choices,
proompters_5_min=stats['stats']['proompters']['5_min'],
proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
)
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
@app.route('/robots.txt')
def robots():
# TODO: have config value to deny all
# TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt
t = """User-agent: *
Allow: /"""
r = Response(t)
r.headers['Content-Type'] = 'text/plain'
return r
@app.route('/<first>')
@app.route('/<first>/<path:rest>')

View File

@ -65,6 +65,19 @@
.hidden {
display: none;
}
.header-workers {
font-weight: normal;
font-size: 14pt;
}
h3 {
font-size: 16pt;
}
.no-marker {
list-style: none;
}
</style>
</head>
@ -76,8 +89,12 @@
<h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1>
<div class="info-box">
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
<p><strong>Current Model:</strong> <span id="model">{{ default_model }}</span></p>
<p>
<strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ default_estimated_wait }}</span><br>
Processing: {{ default_active_gen_workers }}<br>
Queued: {{ default_proompters_in_queue }}
</p>
<br>
<p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
@ -91,17 +108,20 @@
<br>
<div class="info-box">
<div id="oobabooga">
<strong>Instructions:</strong>
<h3>Instructions</h3>
<div id="instructions">
<ol>
<li>In Settings > Power User Options, enable <kbd>Relaxed API URLS</kbd>.</li>
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
{% if enable_streaming %}<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>{% endif %}
{% if enable_streaming %}
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
{% endif %}
<li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox.
</li>
<li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ default_context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
</li>
</ol>
@ -120,13 +140,45 @@
<br>
<div class="info-box">
<pre><code class="language-json" style="background-color: white">{{ stats_json|safe }}</code></pre>
<h3>Statistics</h3>
Proompters:
<ul style="margin-top: 5px;">
<li class="no-marker">5 minutes: {{ proompters_5_min }}</li>
<li class="no-marker">24 hours: {{ proompters_24_hrs }}</li>
</ul>
</div>
<br>
{% for key, value in model_choices.items() %}
<div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
{% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
{% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
{% else %}
{% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
{% endif %}
<p>
<strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
Processing: {{ value.processing }}<br>
Queued: {{ value.queued }}<br>
</p>
<p>
<strong>Client API URL:</strong> {{ value.client_api }}<br>
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
</p>
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
</div>
<br>
{% endfor %}
</div>
<div class="footer">
<a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a>
</div>
<script>hljs.highlightAll();</script>
</body>
</html>