Merge cluster to master #3

Merged
cyberes merged 163 commits from cluster into master 2023-10-27 19:19:22 -06:00
91 changed files with 3139 additions and 1351 deletions

View File

@ -43,7 +43,9 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
### Use
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
### To Do

View File

@ -1,22 +1,19 @@
import time
from llm_server.routes.cache import redis
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import argparse
import logging
import os
import sys
import time
from pathlib import Path
from llm_server.config.load import load_config
from llm_server.database.create import create_db
from redis import Redis
from llm_server.workers.app import start_background
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.load import load_config, parse_backends
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@ -26,19 +23,46 @@ else:
config_path = Path(script_path, 'config', 'config.yml')
if __name__ == "__main__":
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
parser = argparse.ArgumentParser(description='Daemon microservice.')
parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
args = parser.parse_args()
success, config, msg = load_config(config_path, script_path)
# TODO: have this be set by either the arg or a config value
if args.debug:
logging_info.level = logging.DEBUG
init_logging()
logger = create_logger('daemon')
logger.debug('Debug logging enabled.')
if not args.no_reset:
Redis().flushall()
logger.info('Flushed Redis.')
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
logger.info(f'Failed to load config: {msg}')
sys.exit(1)
create_db()
cluster_config.clear()
cluster_config.load(parse_backends(config))
logger.info('Loading backend stats...')
generate_stats(regen=True)
start_background()
redis.set('daemon_started', 1)
print('== Daemon Setup Complete ==\n')
# Give some time for the background threads to get themselves ready to go.
time.sleep(2)
while True:
time.sleep(3600)
redis.set('daemon_started', 1)
logger.info('== Daemon Setup Complete ==')
try:
while True:
time.sleep(3600)
except KeyboardInterrupt:
redis.set('daemon_started', 0)

View File

View File

@ -0,0 +1,117 @@
import numpy as np
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.llm.info import get_info
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
def get_backends_from_model(model_name: str):
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
def get_running_models():
return redis_running_models.keys()
def purge_backend_from_running_models(backend_url: str):
keys = redis_running_models.keys()
pipeline = redis_running_models.pipeline()
for model in keys:
pipeline.srem(model, backend_url)
pipeline.execute()
def is_valid_model(model_name: str):
return redis_running_models.exists(model_name)
def test_backend(backend_url: str, test_prompt: bool = False):
backend_info = cluster_config.get_backend(backend_url)
if test_prompt:
handler = VLLMBackend(backend_url)
parameters, _ = handler.get_parameters({
"stream": False,
"temperature": 0,
"max_new_tokens": 3,
})
data = {
'prompt': 'test prompt',
**parameters
}
try:
success, response, err = generator(data, backend_url, timeout=10)
if not success or not response or err:
return False, {}
except:
return False, {}
i = get_info(backend_url, backend_info['mode'])
if not i.get('model'):
return False, {}
return True, i
def get_model_choices(regen: bool = False):
if not regen:
c = redis.getp('model_choices')
if c:
return c
base_client_api = redis.get('base_client_api', dtype=str)
running_models = get_running_models()
model_choices = {}
for model in running_models:
b = get_backends_from_model(model)
context_size = []
avg_gen_per_worker = []
concurrent_gens = 0
for backend_url in b:
backend_info = cluster_config.get_backend(backend_url)
if backend_info.get('model_config'):
context_size.append(backend_info['model_config']['max_position_embeddings'])
if backend_info.get('average_generation_elapsed_sec'):
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
concurrent_gens += backend_info['concurrent_gens']
active_gen_workers = get_active_gen_workers_model(model)
proompters_in_queue = priority_queue.len(model)
if len(avg_gen_per_worker):
average_generation_elapsed_sec = np.average(avg_gen_per_worker)
else:
average_generation_elapsed_sec = 0
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, concurrent_gens, active_gen_workers)
model_choices[model] = {
'model': model,
'client_api': f'https://{base_client_api}/{model}',
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
'backend_count': len(b),
'estimated_wait': estimated_wait_sec,
'queued': proompters_in_queue,
'processing': active_gen_workers,
'avg_generation_time': average_generation_elapsed_sec,
'concurrent_gens': concurrent_gens
}
if len(context_size):
model_choices[model]['context_size'] = min(context_size)
# Sort case-insensitively; Python's default sort orders uppercase letters before lowercase ones.
model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
default_backend_url = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend_url)
if not default_backend_info.get('model'):
return {}, None
default_model = default_backend_info['model']
redis.setp('model_choices', (model_choices, default_model))
return model_choices, default_model
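
For reference, a rough sketch of the data `get_model_choices()` assembles. The module path follows the `llm_server.cluster.backend` import used elsewhere in this PR; the model name, URLs, and numbers are illustrative, assuming `base_client_api` is `proxy.example.com/api`.

```python
# Hypothetical illustration of the (model_choices, default_model) tuple built above.
from llm_server.cluster.backend import get_model_choices

model_choices, default_model = get_model_choices()
# model_choices['MythoMax-L2-13B'] -> {
#     'model': 'MythoMax-L2-13B',
#     'client_api': 'https://proxy.example.com/api/MythoMax-L2-13B',
#     'ws_client_api': 'wss://proxy.example.com/api/MythoMax-L2-13B/v1/stream',
#     'openai_client_api': 'https://proxy.example.com/api/openai/MythoMax-L2-13B/v1',
#     'backend_count': 2,
#     'estimated_wait': 12.5,
#     'queued': 1,
#     'processing': 2,
#     'avg_generation_time': 25.0,
#     'concurrent_gens': 6,
#     'context_size': 4096,
# }
# default_model is the model served by the cluster's default backend.
```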

View File

@ -0,0 +1,124 @@
import hashlib
import pickle
import traceback
from llm_server import opts
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.custom_redis import RedisCustom
from llm_server.routes.helpers.model import estimate_model_size
class RedisClusterStore:
def __init__(self, name: str, **kwargs):
self.name = name
self.config_redis = RedisCustom(name, **kwargs)
def clear(self):
self.config_redis.flush()
def load(self, config: dict):
for k, v in config.items():
self.add_backend(k, v)
def add_backend(self, name: str, values: dict):
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
self.set_backend_value(name, 'online', False)
h = hashlib.sha256(name.encode('utf-8')).hexdigest()
self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
def set_backend_value(self, backend: str, key: str, value):
# By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
self.config_redis.hset(backend, key, pickle.dumps(value))
def get_backend(self, name: str):
r = self.config_redis.hgetall(name)
output = {}
for k, v in r.items():
output[k.decode('utf8')] = pickle.loads(v)
return output
def all(self):
keys = self.config_redis.keys('*')
if keys:
result = {}
for key in keys:
if key != f'{self.name}:____':
v = self.get_backend(key)
result[key] = v
return result
else:
return {}
def validate_backend(self, backend_url: str):
"""
Returns the backend URL that was given, or a replacement if that backend is offline.
:param backend_url:
:return:
"""
backend_info = self.get_backend(backend_url)
if not backend_info['online']:
old = backend_url
backend_url = get_a_cluster_backend()
print(f'Backend {old} offline. Request was redirected to {backend_url}')
return backend_url
cluster_config = RedisClusterStore('cluster_config')
def get_backends():
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b.get('online', False)
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
try:
if not opts.prioritize_by_size:
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
else:
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: estimate_model_size(kv[1]['model_config']),
reverse=True
)
offline_backends = sorted(
((url, info) for url, info in backends.items() if not info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends], [url for url, info in offline_backends]
except KeyError:
traceback.print_exc()
print(backends)
def get_a_cluster_backend(model=None):
"""
Get a backend from Redis. If there are no online backends, return None.
If `model` is not supplied, we will pick one ourselves.
"""
if model:
# First, determine if there are multiple backends hosting the same model.
backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]
# If so, create an iterator for those backends
if len(backends_hosting_model):
add_backend_cycler(model, backends_hosting_model)
cycled = redis_cycle(model)
if len(cycled):
return cycled[0]
else:
# No backend hosting that model
return None
else:
online, _ = get_backends()
if len(online):
return online[0]

View File

@ -0,0 +1,39 @@
import redis
redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)
def redis_cycle(list_name):
"""
Emulates itertools.cycle() but returns the complete shuffled list.
:param list_name:
:return:
"""
pipeline = redis_cycler_db.pipeline()
pipeline.lpop(list_name)
to_move = pipeline.execute()[0]
if not to_move:
return []
pipeline.rpush(list_name, to_move)
pipeline.lrange(list_name, 0, -1)
results = pipeline.execute()
new_list = results[-1]
return [x.decode('utf-8') for x in new_list]
def add_backend_cycler(list_name: str, new_elements: list):
existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
existing_set = set(existing_elements)
with redis_cycler_db.pipeline() as pipe:
# Add elements
for element in new_elements:
if element not in existing_set:
pipe.rpush(list_name, element)
# Remove elements
for element in existing_set:
if element not in new_elements:
pipe.lrem(list_name, 0, element)
pipe.execute()
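
A short usage sketch of the cycler, assuming two registered backends (the URLs and list name are made up).

```python
# Hypothetical illustration: round-robin over the backends hosting one model.
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle

backends = ['https://backend-a.example.com', 'https://backend-b.example.com']  # URLs are illustrative
add_backend_cycler('my-model', backends)

# Each call rotates the Redis-backed list by one element and returns the whole
# list, so taking the head gives a simple round-robin over the backends.
print(redis_cycle('my-model')[0])
print(redis_cycle('my-model')[0])  # the other backend
```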

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_running_models = RedisCustom('running_models')

View File

@ -0,0 +1,38 @@
import time
from threading import Thread
from llm_server.cluster.backend import test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
def cluster_worker():
counter = 0
while True:
test_prompt = False
if counter % 4 == 0:
# Only send a test prompt on every fourth check (roughly once a minute).
test_prompt = True
threads = []
for n, v in cluster_config.all().items():
thread = Thread(target=check_backend, args=(n, v, test_prompt))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
time.sleep(15)
counter += 1
def check_backend(n, v, test_prompt):
online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
if online:
running_model = backend_info['model']
for k, v in backend_info.items():
cluster_config.set_backend_value(n, k, v)
redis_running_models.sadd(running_model, n)
else:
for model in redis_running_models.keys():
redis_running_models.srem(model, n)
cluster_config.set_backend_value(n, 'online', online)

View File

@ -28,16 +28,19 @@ config_default_vars = {
'openai_force_no_hashes': True,
'include_system_tokens_in_stats': True,
'openai_moderation_scan_last_n': 5,
'openai_moderation_workers': 10,
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True,
'netdata_root': None
'netdata_root': None,
'show_backends': True,
'background_homepage_cacher': True,
'openai_moderation_timeout': 5,
'prioritize_by_size': False
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
mode_ui_names = {
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
}

View File

@ -3,38 +3,28 @@ import sys
import openai
import llm_server
from llm_server import opts
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import get_number_of_rows
from llm_server.helpers import resolve_path
from llm_server.routes.cache import redis
from llm_server.routes.queue import PriorityQueue
def load_config(config_path, script_path):
def load_config(config_path):
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
success, config, msg = config_loader.load_config()
if not success:
return success, config, msg
# Resolve relative directory to the directory of the script
if config['database_path'].startswith('./'):
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
if config['mode'] not in ['oobabooga', 'vllm']:
print('Unknown mode:', config['mode'])
sys.exit(1)
# TODO: this is atrocious
opts.mode = config['mode']
opts.auth_required = config['auth_required']
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.backend_url = config['backend_url'].strip('/')
opts.cluster = config['cluster']
opts.show_total_output_tokens = config['show_total_output_tokens']
opts.netdata_root = config['netdata_root']
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
@ -53,10 +43,20 @@ def load_config(config_path, script_path):
opts.openai_force_no_hashes = config['openai_force_no_hashes']
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
opts.openai_moderation_workers = config['openai_moderation_workers']
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
opts.background_homepage_cacher = config['background_homepage_cacher']
opts.openai_moderation_timeout = config['openai_moderation_timeout']
opts.frontend_api_mode = config['frontend_api_mode']
opts.prioritize_by_size = config['prioritize_by_size']
# Scale the number of workers.
for item in config['cluster']:
opts.cluster_workers += item['concurrent_gens']
llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_expose_our_model to true, you must set your OpenAI key in openai_api_key.')
@ -78,6 +78,16 @@ def load_config(config_path, script_path):
if config['load_num_prompts']:
redis.set('proompts', get_number_of_rows('prompts'))
redis.set('backend_mode', opts.mode)
return success, config, msg
def parse_backends(config):
if not config.get('cluster'):
return False
cluster = config.get('cluster')
config = {}
for item in cluster:
backend_url = item['backend_url'].strip('/')
item['backend_url'] = backend_url
config[backend_url] = item
return config

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_config = RedisCustom('redis_config')

View File

@ -1,24 +1,27 @@
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Union
from typing import Callable, List, Mapping, Optional, Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
ONE_MONTH_SECONDS = 2678000
class RedisWrapper:
class RedisCustom(Redis):
"""
A wrapper class to set prefixes to keys.
A simple wrapper class for Redis to create a "namespace" within a DB,
which simplifies key management.
"""
def __init__(self, prefix, **kwargs):
super().__init__()
self.redis = Redis(**kwargs)
self.prefix = prefix
try:
@ -34,12 +37,11 @@ class RedisWrapper:
def set(self, key, value, ex: Union[ExpiryT, None] = None):
return self.redis.set(self._key(key), value, ex=ex)
def get(self, key, dtype=None, default=None):
"""
:param key:
:param dtype: convert to this type
:return:
"""
def get(self, key, default=None, dtype=None):
# TODO: use pickle
import inspect
if inspect.isclass(default):
raise Exception('the default value must be an instance, not a class')
d = self.redis.get(self._key(key))
if dtype and d:
@ -108,7 +110,10 @@ class RedisWrapper:
):
return self.redis.hincrby(self._key(name), key, amount)
def hdel(self, name: str, *keys: List):
def zcard(self, name: KeyT):
return self.redis.zcard(self._key(name))
def hdel(self, name: str, *keys: str):
return self.redis.hdel(self._key(name), *keys)
def hget(
@ -129,9 +134,62 @@ class RedisWrapper:
):
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
def lpush(self, name: str, *values: FieldT):
return self.redis.lpush(self._key(name), *values)
def hset(
self,
name: str,
key: Optional = None,
value=None,
mapping: Optional[dict] = None,
items: Optional[list] = None,
):
return self.redis.hset(self._key(name), key, value, mapping, items)
def hkeys(self, name: str):
return self.redis.hkeys(self._key(name))
def hmget(self, name: str, keys: List, *args: List):
return self.redis.hmget(self._key(name), keys, *args)
def hgetall(self, name: str):
return self.redis.hgetall(self._key(name))
def keys(self, pattern: PatternT = "*", **kwargs):
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
keys = []
for key in raw_keys:
p = key.decode('utf-8').split(':')
if len(p) >= 2:
# Delete prefix
del p[0]
k = ':'.join(p)
if k != '____':
keys.append(k)
return keys
def pipeline(self, transaction=True, shard_hint=None):
return self.redis.pipeline(transaction, shard_hint)
def smembers(self, name: str):
return self.redis.smembers(self._key(name))
def spop(self, name: str, count: Optional[int] = None):
return self.redis.spop(self._key(name), count)
def rpoplpush(self, src, dst):
return self.redis.rpoplpush(src, dst)
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
return self.redis.zpopmin(self._key(name), count)
def exists(self, *names: KeyT):
n = []
for name in names:
n.append(self._key(name))
return self.redis.exists(*n)
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex)
@ -142,6 +200,15 @@ class RedisWrapper:
else:
return json.loads(r.decode("utf-8"))
def setp(self, name, value):
self.redis.set(self._key(name), pickle.dumps(value))
def getp(self, name: str):
r = self.redis.get(self._key(name))
if r:
return pickle.loads(r)
return r
def flush(self):
flushed = []
for key in self.redis.scan_iter(f'{self.prefix}:*'):
@ -149,5 +216,40 @@ class RedisWrapper:
self.redis.delete(key)
return flushed
def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
self.flush()
return True
redis = RedisWrapper('local_llm')
def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
self.flush()
return True
def lrange(self, name: str, start: int, end: int):
return self.redis.lrange(self._key(name), start, end)
def delete(self, *names: KeyT):
return self.redis.delete(*[self._key(i) for i in names])
def lpop(self, name: str, count: Optional[int] = None):
return self.redis.lpop(self._key(name), count)
def zrange(
self,
name: KeyT,
start: int,
end: int,
desc: bool = False,
withscores: bool = False,
score_cast_func: Union[type, Callable] = float,
byscore: bool = False,
bylex: bool = False,
offset: int = None,
num: int = None,
):
return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
def zrem(self, name: KeyT, *values: FieldT):
return self.redis.zrem(self._key(name), *values)
redis = RedisCustom('local_llm')
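
A brief sketch of the namespacing behavior, using the pickle helpers and key handling shown above (namespace names and values are made up).

```python
# Hypothetical illustration: two namespaces sharing one Redis DB without key collisions.
from llm_server.custom_redis import RedisCustom

a = RedisCustom('namespace_a')
b = RedisCustom('namespace_b')

a.setp('config', {'workers': 4})   # stored pickled under the key 'namespace_a:config'
print(a.getp('config'))            # -> {'workers': 4}
print(b.getp('config'))            # -> None; 'namespace_b:config' was never set

print(a.keys())                    # prefixes are stripped from the returned keys
a.flush()                          # deletes only the 'namespace_a:*' keys
```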

View File

@ -5,20 +5,20 @@ class DatabaseConnection:
host: str = None
username: str = None
password: str = None
database: str = None
database_name: str = None
def init_db(self, host, username, password, database):
def init_db(self, host, username, password, database_name):
self.host = host
self.username = username
self.password = password
self.database = database
self.database_name = database_name
def cursor(self):
db = pymysql.connect(
host=self.host,
user=self.username,
password=self.password,
database=self.database,
database=self.database_name,
charset='utf8mb4',
autocommit=True,
)

View File

@ -1,15 +1,19 @@
import json
import time
import traceback
from typing import Union
import llm_server
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
from llm_server.routes.cache import redis
from llm_server.llm import get_token_count
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
# Try not to shove JSON into the database.
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
try:
@ -19,10 +23,11 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
except:
pass
prompt_tokens = llm_server.llm.get_token_count(prompt)
prompt_tokens = get_token_count(prompt, backend_url)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)
response_tokens = get_token_count(response, backend_url)
else:
response_tokens = None
@ -43,7 +48,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if token:
increment_token_uses(token)
running_model = redis.get('running_model', str, 'ERROR')
backend_info = cluster_config.get_backend(backend_url)
running_model = backend_info.get('model')
backend_mode = backend_info['mode']
timestamp = int(time.time())
cursor = database.cursor()
try:
@ -52,7 +59,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
@ -179,3 +186,21 @@ def increment_token_uses(token):
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
finally:
cursor.close()
def get_token_ratelimit(token):
priority = 9990
simultaneous_ip = opts.simultaneous_requests_per_ip
if token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip
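
A small sketch of how the defaults above behave; the token string is made up, and the import path is assumed to be the same `llm_server.database.database` module modified in this hunk.

```python
# Hypothetical illustration: resolving limits for anonymous vs. authenticated requests.
from llm_server.database.database import get_token_ratelimit  # assumed module path

# No token: priority 9990 and the global simultaneous_requests_per_ip setting.
priority, simultaneous_ip = get_token_ratelimit(None)

# With a token (value is made up), the token_auth row can override both values;
# a NULL simultaneous_ip in the row means the token is effectively unlimited per IP.
priority, simultaneous_ip = get_token_ratelimit('example-token-1234')
```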

View File

@ -0,0 +1,30 @@
import pickle
from typing import Union
from redis import Redis
def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
r = Redis(host='localhost', port=6379, db=3)
data = {
'function': 'log_prompt',
'args': [],
'kwargs': {
'ip': ip,
'token': token,
'prompt': prompt,
'response': response,
'gen_time': gen_time,
'parameters': parameters,
'headers': dict(headers) if headers else headers,
'backend_response_code': backend_response_code,
'request_url': request_url,
'backend_url': backend_url,
'response_tokens': response_tokens,
'is_error': is_error
}
}
r.publish('database-logger', pickle.dumps(data))
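
The consumer side of this channel is not shown in this excerpt. A hedged sketch of what a subscriber might look like, assuming the database worker unpickles the payload and replays it through `do_db_log()` from `llm_server.database.database`:

```python
# Hypothetical sketch of the consumer side (the actual worker is not shown in this diff).
import pickle
from redis import Redis

from llm_server.database.database import do_db_log  # assumed location of the synchronous writer

r = Redis(host='localhost', port=6379, db=3)
pubsub = r.pubsub()
pubsub.subscribe('database-logger')

for message in pubsub.listen():
    if message['type'] != 'message':
        continue  # skip the subscribe confirmation
    data = pickle.loads(message['data'])
    if data.get('function') == 'log_prompt':
        # Replaying here keeps database I/O off the request thread.
        do_db_log(*data['args'], **data['kwargs'])
```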

View File

@ -8,7 +8,7 @@ import simplejson as json
from flask import make_response
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def resolve_path(*p: str):
@ -54,13 +54,14 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys
def round_up_base(n, base):
if base == 0:
print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
# TODO: I don't think passing (0, 0) to this function is a sign of any underlying issues.
# print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
return 0
return math.ceil(n / base) * base
def auto_set_base_client_api(request):
http_host = redis.get('http_host', str)
http_host = redis.get('http_host', dtype=str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
# If the current http_host is not an IP, don't do anything.

View File

@ -1,12 +0,0 @@
import threading
class ThreadSafeInteger:
def __init__(self, value=0):
self.value = value
self._value_lock = threading.Lock()
def increment(self):
with self._value_lock:
self.value += 1
return self.value

View File

@ -1,11 +1,30 @@
import tiktoken
from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import oobabooga, vllm
from llm_server.routes.cache import redis
from llm_server.logging import create_logger
def get_token_count(prompt: str):
backend_mode = redis.get('backend_mode', str)
def fallback_tokenizer(prompt: str):
tokenizer = tiktoken.get_encoding("cl100k_base")
return len(tokenizer.encode(prompt)) + 10
def get_token_count(prompt: str, backend_url: str):
backend_url = cluster_config.validate_backend(backend_url)
if not backend_url:
logger = create_logger('tokenizer')
logger.warning('using fallback tokenizer as there is no valid backend')
return fallback_tokenizer(prompt)
backend_mode = cluster_config.get_backend(backend_url).get('mode')
if not backend_mode:
logger = create_logger('tokenizer')
logger.warning("using fallback tokenizer as the backend isn't initalized")
return fallback_tokenizer(prompt)
if backend_mode == 'vllm':
return vllm.tokenize(prompt)
return vllm.tokenize(prompt, backend_url)
elif backend_mode == 'ooba':
return oobabooga.tokenize(prompt)
else:

View File

@ -1,14 +1,15 @@
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
def generator(request_json_body):
if opts.mode == 'oobabooga':
def generator(request_json_body, cluster_backend, timeout: int = None):
mode = cluster_config.get_backend(cluster_backend)['mode']
if mode == 'ooba':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm':
elif mode == 'vllm':
from .vllm.generate import generate
r = generate(request_json_body)
return r
return generate(request_json_body, cluster_backend, timeout=timeout)
else:
raise Exception

View File

@ -3,23 +3,35 @@ import requests
from llm_server import opts
def get_running_model():
# TODO: cache the results for 1 min so we don't have to keep calling the backend
# TODO: only use one try/catch
if opts.mode == 'oobabooga':
def get_running_model(backend_url: str, mode: str):
if mode == 'ooba':
try:
backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['result'], None
except Exception as e:
return False, e
elif opts.mode == 'vllm':
elif mode == 'vllm':
try:
backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['model'], None
except Exception as e:
return False, e
else:
raise Exception
def get_info(backend_url: str, mode: str):
if mode == 'ooba':
return {}
# raise NotImplementedError
elif mode == 'vllm':
try:
r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
j = r.json()
except Exception as e:
return {}
return j
else:
raise Exception

View File

@ -2,14 +2,17 @@ from typing import Tuple, Union
import flask
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
class LLMBackend:
_default_params: dict
def __init__(self, backend_url: str):
self.backend_url = backend_url
self.backend_info = cluster_config.get_backend(self.backend_url)
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError
@ -32,14 +35,16 @@ class LLMBackend:
"""
If a backend needs to do other checks not related to the prompt or parameters.
By default, no extra checks are performed.
:param request:
:param prompt:
:param parameters:
:return:
"""
return True, None
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = get_token_count(prompt)
if prompt_len > opts.context_size - 10:
model_name = redis.get('running_model', str, 'NO MODEL ERROR')
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
prompt_len = get_token_count(prompt, self.backend_url)
token_limit = self.backend_info['model_config']['max_position_embeddings']
if prompt_len > token_limit - 10:
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
return True, None

View File

@ -1,78 +1,6 @@
from flask import jsonify
from ..llm_backend import LLMBackend
from ...database.database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
class OobaboogaBackend(LLMBackend):
default_params = {}
def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError('need to implement default_params')
backend_err = False
response_valid_json, response_json_body = validate_json(response)
if response:
try:
# Be extra careful when getting attributes from the response object
response_status_code = response.status_code
except:
response_status_code = 0
else:
response_status_code = None
# ===============================================
# We encountered an error
if not success or not response or error_msg:
if not error_msg or error_msg == '':
error_msg = 'Unknown error.'
else:
error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, 'error')
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 400
# ===============================================
if response_valid_json:
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
# Ooba doesn't return any error messages so we will just tell the client an error occurred
backend_err = True
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
'error')
response_json_body['results'][0]['text'] = backend_response
if not backend_err:
redis.incr('proompts')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({
**response_json_body
}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': 'the backend did not return valid JSON',
'results': [{'text': backend_response}]
}), 400
def validate_params(self, params_dict: dict):
# No validation required
return True, None
def get_parameters(self, parameters):
del parameters['prompt']
return parameters
def __int__(self):
return

View File

@ -10,7 +10,7 @@ def check_moderation_endpoint(prompt: str):
}
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
if response.status_code != 200:
print(response.text)
print('moderation failed:', response)
response.raise_for_status()
response = response.json()

View File

@ -0,0 +1,97 @@
from flask import jsonify
from llm_server import opts
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
if not request_json_body.get('stop'):
request_json_body['stop'] = []
if not isinstance(request_json_body['stop'], list):
# It is a string, so create a list with the existing element.
request_json_body['stop'] = [request_json_body['stop']]
if stop_hashes:
if opts.openai_force_no_hashes:
request_json_body['stop'].append('###')
else:
# TODO: make stopping strings configurable
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
else:
request_json_body['stop'].extend(['user:', 'assistant:'])
if request_json_body.get('frequency_penalty', 0) < -2:
request_json_body['frequency_penalty'] = -2
elif request_json_body.get('frequency_penalty', 0) > 2:
request_json_body['frequency_penalty'] = 2
if mode == 'vllm' and request_json_body.get('top_p') == 0:
request_json_body['top_p'] = 0.01
request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
if request_json_body['max_tokens'] == 0:
# We don't want to set any defaults here.
del request_json_body['max_tokens']
return request_json_body
def format_oai_err(err_msg):
print('OAI ERROR MESSAGE:', err_msg)
return jsonify({
"error": {
"message": err_msg,
"type": "invalid_request_error",
"param": None,
"code": None
}
}), 400
def validate_oai(parameters):
if parameters.get('messages'):
for m in parameters['messages']:
if m['role'].lower() not in ['assistant', 'user', 'system']:
return format_oai_err('messages role must be assistant, user, or system')
if parameters.get('temperature', 0) > 2:
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
if parameters.get('temperature', 0) < 0:
return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'")
if parameters.get('top_p', 1) > 1:
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
if parameters.get('top_p', 1) < 0:
return format_oai_err(f"{parameters['top_p']} is less than the minimum of 0 - 'top_p'")
if parameters.get('presence_penalty', 1) > 2:
return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
if parameters.get('presence_penalty', 1) < -2:
return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'")
if parameters.get('max_tokens', 2) < 1:
return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
def return_invalid_model_err(requested_model: str):
if requested_model:
msg = f"The model `{requested_model}` does not exist"
else:
msg = "The requested model does not exist"
return jsonify({
"error": {
"message": msg,
"type": "invalid_request_error",
"param": None,
"code": "model_not_found"
}
}), 404
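
A hedged illustration of the translation performed by `oai_to_vllm()` above; the module path of this new file is not shown in the excerpt, so no import is given, and the request values are made up.

```python
# Hypothetical illustration of the OpenAI -> VLLM parameter translation.
request = {
    'stop': '### USER',       # a bare string gets wrapped into a list
    'temperature': 0.7,
    'frequency_penalty': 5,   # clamped into the [-2, 2] range
    'max_tokens': 0,          # 0 means "no explicit limit", so the key is removed
}
vllm_request = oai_to_vllm(request, stop_hashes=True, mode='vllm')
# With openai_force_no_hashes enabled, '###' is appended to the stop list;
# otherwise the '### INSTRUCTION' / '### USER' / '### ASSISTANT' markers are added.
```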

View File

@ -2,86 +2,35 @@ import concurrent.futures
import re
import secrets
import string
import time
import traceback
from typing import Dict, List
import tiktoken
from flask import jsonify, make_response
import llm_server
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.routes.cache import redis
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
def build_openai_response(prompt, response, model=None):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
# y = response.split('\n### ')
# if len(y) > 1:
# response = re.sub(r'\n$', '', y[0].strip(' '))
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
def generate_oai_string(length=24):
alphabet = string.ascii_letters + string.digits
return ''.join(secrets.choice(alphabet) for i in range(length))
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
tokenizer = tiktoken.get_encoding("cl100k_base")
def get_token_count_tiktoken_thread(msg):
return len(tokenizer.encode(msg["content"]))
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
def get_token_count_thread(msg):
return get_token_count(msg["content"], backend_url)
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
# If total tokens exceed the limit, start trimming
if total_tokens > context_token_limit:
if total_tokens + formatting_tokens > context_token_limit:
while True:
while total_tokens + formatting_tokens > context_token_limit:
# Calculate the index to start removing messages from
@ -94,22 +43,43 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
break
def get_token_count_thread(msg):
return get_token_count(msg["content"])
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
total_tokens = sum(token_counts)
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
if total_tokens + formatting_tokens > context_token_limit:
# Start over, but this time calculate the token count using the backend
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
token_counts = list(executor.map(get_token_count_thread, prompt))
else:
break
return prompt
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
tokenizer = tiktoken.get_encoding("cl100k_base")
token_count = get_token_count(prompt, backend_url)
# If total tokens exceed the limit, start trimming
if token_count > context_token_limit:
while True:
while token_count > context_token_limit:
# Calculate the index to start removing characters from
remove_index = len(prompt) // 3
while remove_index < len(prompt):
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
token_count = len(tokenizer.encode(prompt))
if token_count <= context_token_limit or remove_index == len(prompt):
break
token_count = get_token_count(prompt, backend_url)
if token_count > context_token_limit:
# Start over, but this time calculate the token count using the backend
token_count = get_token_count(prompt, backend_url)
else:
break
return prompt
@ -117,8 +87,9 @@ def transform_messages_to_prompt(oai_messages):
try:
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
for msg in oai_messages:
if not msg.get('content') or not msg.get('role'):
if 'content' not in msg.keys() or 'role' not in msg.keys():
return False
msg['content'] = str(msg['content']) # Prevent any weird issues.
if msg['role'] == 'system':
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
elif msg['role'] == 'user':
@ -126,7 +97,7 @@ def transform_messages_to_prompt(oai_messages):
elif msg['role'] == 'assistant':
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
else:
return False
raise Exception(f'Unknown role: {msg["role"]}')
except Exception as e:
# TODO: use logging
traceback.print_exc()

View File

@ -1,80 +1,21 @@
"""
This file is used by the worker that processes requests.
"""
import json
import time
from uuid import uuid4
import requests
import llm_server
from llm_server import opts
from llm_server.routes.cache import redis
# TODO: make the VLMM backend return TPS and time elapsed
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
def prepare_json(json_data: dict):
# logit_bias is not currently supported
# del json_data['logit_bias']
# Convert back to VLLM.
json_data['max_tokens'] = json_data.pop('max_new_tokens')
return json_data
def transform_to_text(json_request, api_response):
"""
This is to convert a streaming request to a non-streamed request. Don't think this is necessary.
:param json_request:
:param api_response:
:return:
"""
prompt = transform_prompt_to_text(json_request['messages'])
text = ''
finish_reason = None
for line in api_response.split('\n'):
if line.startswith('data:'):
try:
data = json.loads(line[5:].strip())
except json.decoder.JSONDecodeError:
break
if 'choices' in data:
for choice in data['choices']:
if 'delta' in choice and 'content' in choice['delta']:
text += choice['delta']['content']
if data['choices'][0]['finish_reason']:
finish_reason = data['choices'][0]['finish_reason']
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
completion_tokens = len(llm_server.llm.get_token_count(text))
running_model = redis.get('running_model', str, 'ERROR')
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
return {
"id": str(uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": running_model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
},
"choices": [
{
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason,
"index": 0
}
]
}
def transform_prompt_to_text(prompt: list):
text = ''
for item in prompt:
@ -82,26 +23,26 @@ def transform_prompt_to_text(prompt: list):
return text.strip('\n')
def handle_blocking_request(json_data: dict):
def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
try:
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except requests.exceptions.ReadTimeout:
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
# print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out'
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False, None, 'Request to backend encountered error'
if r.status_code != 200:
print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
return False, r, f'Backend returned {r.status_code}'
return True, r, None
def generate(json_data: dict):
def generate(json_data: dict, cluster_backend, timeout: int = None):
if json_data.get('stream'):
try:
return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False
else:
return handle_blocking_request(json_data)
return handle_blocking_request(json_data, cluster_backend, timeout=timeout)

View File

@ -1,3 +1,7 @@
import requests
from llm_server import opts
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong>
<ul>
@ -7,4 +11,4 @@ vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="
<li><kbd>max_new_tokens</kbd></li>
<li><kbd>num_beams</kbd> <span style="font-size:9pt">(setting to greater than 1 enables beam search)</span></li>
<li><kbd>ban_eos_token</kbd></li>
</ul>"""
</ul>"""

View File

@ -1,26 +1,51 @@
import concurrent.futures
import requests
import tiktoken
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.logging import create_logger
def tokenize(prompt: str) -> int:
def tokenize(prompt: str, backend_url: str) -> int:
assert backend_url
assert isinstance(backend_url, str)
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
assert isinstance(prompt, str)
logger = create_logger('tokenizer')
# The backend could have died between when the request was
# submitted and now, so let's double check it's still online.
backend_url = cluster_config.validate_backend(backend_url)
tokenizer = tiktoken.get_encoding("cl100k_base")
# First we tokenize it locally to determine if it's worth sending it to the backend.
initial_estimate = len(tokenizer.encode(prompt))
if initial_estimate <= opts.context_size + 200:
# Split the prompt into 2000 character chunks
chunk_size = 2000
chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
# Define a function to send a chunk to the server
def send_chunk(chunk):
try:
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json()
return j['length']
except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
return len(tokenizer.encode(prompt)) + 10
else:
# If the result was greater than our context size, return the estimate.
# We won't be sending it through the backend so it doesn't need to be accurate.
return initial_estimate
logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
return len(tokenizer.encode(chunk)) + 10
# Use a ThreadPoolExecutor to send all chunks to the server at once
with concurrent.futures.ThreadPoolExecutor() as executor:
future_to_chunk = {executor.submit(send_chunk, chunk): chunk for chunk in chunks}
for future in concurrent.futures.as_completed(future_to_chunk):
chunk = future_to_chunk[future]
try:
data = future.result()
except Exception as exc:
logger.warning('%r generated an exception: %s' % (chunk, exc))
return sum(future.result() for future in future_to_chunk)

View File

@ -1,10 +1,9 @@
import threading
from typing import Tuple
from flask import jsonify
from vllm import SamplingParams
from llm_server.database.database import log_prompt
from llm_server.database.log_to_db import log_to_db
from llm_server.llm.llm_backend import LLMBackend
@ -19,16 +18,8 @@ class VLLMBackend(LLMBackend):
# Failsafe
backend_response = ''
r_url = request.url
def background_task():
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
return jsonify({'results': [{'text': backend_response}]}), 200
@ -38,14 +29,20 @@ class VLLMBackend(LLMBackend):
top_k = parameters.get('top_k', self._default_params['top_k'])
if top_k <= 0:
top_k = -1
# TODO: support more params
sampling_params = SamplingParams(
temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']),
top_k=top_k,
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
stop=parameters.get('stopping_strings', self._default_params['stop']),
stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
ignore_eos=parameters.get('ban_eos_token', False),
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
)
except ValueError as e:
return None, str(e).strip('.')

llm_server/logging.py Normal file
View File

@ -0,0 +1,52 @@
import logging
import coloredlogs
from llm_server import opts
class LoggingInfo:
def __init__(self):
self._level = logging.INFO
self._format = opts.LOGGING_FORMAT
@property
def level(self):
return self._level
@level.setter
def level(self, value):
self._level = value
@property
def format(self):
return self._format
@format.setter
def format(self, value):
self._format = value
logging_info = LoggingInfo()
def init_logging():
"""
Set up the parent logger.
:return:
"""
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)
def create_logger(name):
logger = logging.getLogger('llm_server').getChild(name)
logger.setLevel(logging_info.level)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging_info.level)
formatter = logging.Formatter(logging_info.format)
handler.setFormatter(formatter)
logger.addHandler(handler)
coloredlogs.install(logger=logger, level=logging_info.level)
return logger
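
A minimal usage sketch mirroring the daemon startup shown earlier in this diff; the logger name is illustrative.

```python
# Hypothetical illustration of the logging setup used by daemon.py above.
import logging

from llm_server.logging import create_logger, init_logging, logging_info

logging_info.level = logging.DEBUG  # must be set before loggers are created
init_logging()                      # configures the parent 'llm_server' logger

logger = create_logger('daemon')    # child logger: 'llm_server.daemon'
logger.debug('Debug logging enabled.')
```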

llm_server/messages.py Normal file
View File

@ -0,0 +1 @@
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'

View File

@ -1,52 +0,0 @@
import json
from datetime import datetime, timedelta
import requests
from llm_server import opts
def get_power_states():
gpu_num = 0
output = {}
while True:
url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
try:
response = requests.get(url, timeout=10)
if response.status_code != 200:
break
data = json.loads(response.text)
power_state_data = data['data'][0]
power_state = None
for i in range(1, len(power_state_data)):
if power_state_data[i] == 1:
power_state = data['labels'][i]
break
output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
except Exception as e:
print('Failed to fetch Netdata metrics:', e)
return output
gpu_num += 1
return output
def get_gpu_wh(gpu_id: int):
chart_name = f"nvidia_smi.gpu{gpu_id}_power"
now = datetime.now()
one_hour_ago = now - timedelta(hours=1)
num_seconds = int((now - one_hour_ago).total_seconds())
params = {
"chart": chart_name,
"after": int(one_hour_ago.timestamp()),
"before": int(now.timestamp()),
"points": num_seconds,
"group": "second",
"format": "json",
"options": "absolute|jsonwrap"
}
response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
data = json.loads(response.text)
total_power_usage_watts = sum(point[1] for point in data['result']['data'])
# total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
return total_power_usage_kwh

View File

@ -1,12 +1,11 @@
# Read-only global variables
# Uppercase variables are read-only globals.
# Lowercase variables are ones that are set on startup and are never changed.
# TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'ERROR'
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None
context_size = 5555
frontend_api_mode = 'ooba'
max_new_tokens = 500
auth_required = False
log_prompts = False
@ -33,7 +32,15 @@ openai_expose_our_model = False
openai_force_no_hashes = True
include_system_tokens_in_stats = True
openai_moderation_scan_last_n = 5
openai_moderation_workers = 10
openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True
background_homepage_cacher = True
openai_moderation_timeout = 5
prioritize_by_size = False
cluster_workers = 0
redis_stream_timeout = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"

View File

@ -1,21 +1,9 @@
import sys
from redis import Redis
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.custom_redis import redis
def server_startup(s):
if not redis.get('daemon_started', bool):
if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
# Cache the initial stats
print('Loading backend stats...')
generate_stats()

View File

@ -1,11 +1,18 @@
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, level: str = 'info'):
http_host = redis.get('http_host', str)
def format_sillytavern_err(msg: str, backend_url: str = None, error_type: str = 'info'):
if backend_url:
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
else:
cluster_backend_hash = 'none'
http_host = redis.get('http_host', dtype=str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
-> {error_type.upper()} <-
{msg}
```
```
BACKEND: {cluster_backend_hash}
```"""
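Aside (not part of the diff): a hypothetical call to the updated helper; the message and backend URL are made up. The result is a fenced block containing the middleware hostname, the upper-cased error type, and the message, followed by a second block with the backend's hash (or 'none' when no backend_url is given).

msg = format_sillytavern_err(
    'The backend did not return valid JSON.',
    backend_url='http://10.0.0.50:7000',  # hypothetical backend
    error_type='error',
)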

View File

@ -100,4 +100,4 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
j = json.loads(str(data))
return True, j
except Exception as e:
return False, e
return False, e

View File

@ -0,0 +1,15 @@
def estimate_model_size(config: dict):
"""
Estimate the size of a model from its config. No idea if this is correct,
but it allows us to compare models.
:param config:
:return:
"""
vocab_size = config.get('vocab_size')
hidden_size = config.get('hidden_size')
num_hidden_layers = config.get('num_hidden_layers')
intermediate_size = config.get('intermediate_size')
if vocab_size and hidden_size and num_hidden_layers and intermediate_size:
total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3)))
return int(total_params / 1e9)
return 0
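Aside (not part of the diff): a worked call using config values in the ballpark of a 13B Llama-style model (the numbers are illustrative, not taken from this repository).

config = {
    'vocab_size': 32000,
    'hidden_size': 5120,
    'num_hidden_layers': 40,
    'intermediate_size': 13824,
}
print(estimate_model_size(config))  # -> 14, i.e. roughly 14B parameters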

View File

@ -3,8 +3,8 @@ from typing import Tuple
import flask
from flask import jsonify, request
from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server import messages, opts
from llm_server.database.log_to_db import log_to_db
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler
@ -13,8 +13,11 @@ class OobaRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def handle_request(self):
def handle_request(self, return_ok: bool = True):
assert not self.used
if self.offline:
print('This backend is offline:', messages.BACKEND_OFFLINE)
return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()
if not request_valid:
@ -25,14 +28,19 @@ class OobaRequestHandler(RequestHandler):
llm_request = {**self.parameters, 'prompt': prompt}
_, backend_response = self.generate_response(llm_request)
return backend_response
if return_ok:
# Always return 200 so ST displays our error messages
return backend_response[0], 200
else:
# The OpenAI route needs to detect 429 errors.
return backend_response
def handle_ratelimited(self, do_log: bool = True):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return backend_response[0], 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
@ -40,7 +48,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this
response_msg = error_msg
else:
response_msg = format_sillytavern_err(error_msg, error_type)
response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
return jsonify({
'results': [{'text': response_msg}]

View File

@ -5,9 +5,11 @@ from ..server_error import handle_server_error
from ... import opts
openai_bp = Blueprint('openai/v1/', __name__)
openai_model_bp = Blueprint('openai/', __name__)
@openai_bp.before_request
@openai_model_bp.before_request
def before_oai_request():
if not opts.enable_openi_compatible_backend:
return 'The OpenAI-compatible backend is disabled.', 401
@ -15,8 +17,22 @@ def before_oai_request():
@openai_bp.errorhandler(500)
@openai_model_bp.errorhandler(500)
def handle_error(e):
return handle_server_error(e)
"""
Found Codes:
"auth_subrequest_error"
"""
print('OAI returning error:', e)
return jsonify({
"error": {
"message": "Internal server error",
"type": "auth_subrequest_error",
"param": None,
"code": "internal_error"
}
}), 500
from .models import openai_list_models

View File

@ -1,113 +1,175 @@
import json
import threading
import time
import traceback
import ujson
from flask import Response, jsonify, request
from redis import Redis
from . import openai_bp
from ..cache import redis
from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
from ...llm.vllm import tokenize
from ...database.log_to_db import log_to_db
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
# TODO: add rate-limit headers?
@openai_bp.route('/chat/completions', methods=['POST'])
def openai_chat_completions():
@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
def openai_chat_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
else:
handler = OpenAIRequestHandler(request, request_json_body)
if request_json_body.get('stream'):
if not opts.enable_streaming:
# TODO: return a proper OAI error message
return 'disabled', 401
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
return return_invalid_model_err(model_name)
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
response_status_code = 0
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
try:
response = generator(msg_to_backend)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
return Response(generate(), mimetype='text/event-stream')
except:
# TODO: simulate OAI here
raise Exception
else:
if not request_json_body.get('stream'):
try:
return handler.handle_request()
except Exception:
traceback.print_exc()
return 'Internal server error', 500
else:
if not opts.enable_streaming:
return 'Streaming disabled', 403
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
handler.parameters, e = handler.get_parameters()
handler.request_json_body = {
'messages': handler.request_json_body['messages'],
'model': handler.request_json_body['model'],
**handler.parameters
}
if opts.openai_silent_trim:
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
# Need to set the prompt in the JSON body since that's what the inference worker expects.
handler.request_json_body['prompt'] = handler.prompt
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
429,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
try:
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
# Need to do this before we enter generate() since we want to be able to
# return a 408 if necessary.
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None # set to null so that the Finally ignores it.
return 'Request Timeout', 408
def generate():
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
# Not printing error since we can just check the daemon log.
print('OAI streaming encountered error')
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": data['new']
},
"finish_reason": None
}
]
}
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except GeneratorExit:
return
except Exception:
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500
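Aside (not part of the diff): the generator above only assumes that each Redis stream entry carries a 'data' field holding a JSON object with 'new', 'completed', and 'error' keys. A hypothetical sketch of the producer side of that contract (the actual inference worker is not shown in this hunk) could look like:

import ujson
from redis import Redis

def publish_chunk(stream_name: str, new_text: str = None, completed: bool = False, error: str = None):
    # db 8 matches the stream_redis reader in the SSE generator above.
    stream_redis = Redis(db=8)
    stream_redis.xadd(stream_name, {
        'data': ujson.dumps({'new': new_text, 'completed': completed, 'error': error})
    })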

View File

@ -1,38 +1,68 @@
import time
import traceback
from flask import jsonify, make_response, request
import simplejson as json
import ujson
from flask import Response, jsonify, request
from redis import Redis
from . import openai_bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.openai.transform import build_openai_response, generate_oai_string
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
# TODO: add rate-limit headers?
@openai_bp.route('/completions', methods=['POST'])
def openai_completions():
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
def openai_completions(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
try:
response, status_code = OobaRequestHandler(request).handle_request()
if status_code != 200:
return status_code
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
return return_invalid_model_err(model_name)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
handler.request_json_body['prompt'] = handler.prompt
if not request_json_body.get('stream'):
invalid_oai_err_msg = validate_oai(request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
response, status_code = handler.handle_request(return_ok=False)
if status_code == 429:
return handler.handle_ratelimited()
output = response.json['results'][0]['text']
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
running_model = redis.get('running_model', str, 'ERROR')
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
response = jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
"object": "text_completion",
"created": int(time.time()),
@ -42,7 +72,7 @@ def openai_completions():
"text": output,
"index": 0,
"logprobs": None,
"finish_reason": None
"finish_reason": "stop"
}
],
"usage": {
@ -50,12 +80,141 @@ def openai_completions():
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
})
stats = redis.get('proxy_stats', dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
except Exception:
traceback.print_exc()
return 'Internal Server Error', 500
# TODO:
# stats = redis.get('proxy_stats', dtype=dict)
# if stats:
# response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response, 200
else:
if not opts.enable_streaming:
return 'Streaming disabled', 403
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
handler.parameters, _ = handler.get_parameters()
handler.request_json_body = {
'prompt': handler.request_json_body['prompt'],
'model': handler.request_json_body['model'],
**handler.parameters
}
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
None,
None,
handler.parameters,
request.headers,
429,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
try:
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
_, stream_name, error_msg = event.wait()
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None
return 'Request Timeout', 408
def generate():
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
print('OAI streaming encountered error')
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": data['new']
},
"finish_reason": None
}
]
}
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except GeneratorExit:
# This should be triggered if a client disconnects early.
return
except Exception:
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500

View File

@ -1,7 +1,7 @@
from flask import Response
from . import openai_bp
from ..cache import flask_cache
from llm_server.custom_redis import flask_cache
from ... import opts

View File

@ -3,59 +3,58 @@ import traceback
import requests
from flask import jsonify
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
from ...llm.openai.transform import generate_oai_string
@openai_bp.route('/models', methods=['GET'])
@flask_cache.cached(timeout=60, query_string=True)
def openai_list_models():
model, error = get_running_model()
if not model:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not model_name:
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
oai = fetch_openai_models()
r = []
r = {
"object": "list",
"data": oai
}
# TODO: verify this works
if opts.openai_expose_our_model:
r = [{
"object": "list",
"data": [
r["data"].insert(0, {
"id": running_model,
"object": "model",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": running_model,
"object": "model",
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"owned_by": opts.llm_middleware_name,
"permission": [
{
"id": running_model,
"object": "model_permission",
"created": int(server_start_time.timestamp()),
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
],
"root": None,
"parent": None
"allow_create_engine": False,
"allow_sampling": False,
"allow_logprobs": False,
"allow_search_indices": False,
"allow_view": True,
"allow_fine_tuning": False,
"organization": "*",
"group": None,
"is_blocking": False
}
]
}]
response = jsonify_pretty(r + oai), 200
],
"root": None,
"parent": None
})
response = jsonify_pretty(r), 200
return response
@ -64,7 +63,14 @@ def fetch_openai_models():
if opts.openai_api_key:
try:
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
return response.json()['data']
j = response.json()['data']
# The "modelperm" string appears to be user-specific, so we'll
# randomize it just to be safe.
for model in range(len(j)):
for p in range(len(j[model]['permission'])):
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
return j
except:
traceback.print_exc()
return []

View File

@ -1,7 +1,7 @@
from flask import jsonify
from . import openai_bp
from ..cache import ONE_MONTH_SECONDS, flask_cache
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
from ...llm.openai.transform import generate_oai_string
from ..stats import server_start_time
@ -17,7 +17,7 @@ def openai_organizations():
"id": f"org-{generate_oai_string(24)}",
"created": int(server_start_time.timestamp()),
"title": "Personal",
"name": "user-abcdefghijklmnopqrstuvwx",
"name": f"user-{generate_oai_string(24)}",
"description": "Personal org for bobjoe@0.0.0.0",
"personal": True,
"is_default": True,

View File

@ -1,14 +1,21 @@
import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4
import flask
from flask import jsonify
from flask import Response, jsonify, make_response
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results
@ -20,20 +27,37 @@ class OpenAIRequestHandler(RequestHandler):
def handle_request(self) -> Tuple[flask.Response, int]:
assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print('OAI Offline:', msg)
return self.handle_error(msg)
if opts.openai_silent_trim:
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
else:
oai_messages = self.request.json['messages']
self.prompt = transform_messages_to_prompt(oai_messages)
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
request_valid, invalid_response = self.validate_request()
if not request_valid:
return invalid_response
if opts.openai_api_key and is_api_key_moderated(self.token):
if not self.prompt:
# TODO: format this as an openai error message
return Response('Invalid prompt'), 400
# TODO: support Ooba backend
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
invalid_oai_err_msg = validate_oai(self.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
try:
# Gather the last message from the user and all preceeding system messages
# Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy()
msg_l.reverse()
tag = uuid4()
@ -49,33 +73,40 @@ class OpenAIRequestHandler(RequestHandler):
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
except Exception as e:
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
print(traceback.format_exc())
# Reconstruct the request JSON with the validated parameters and prompt.
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
if opts.openai_force_no_hashes:
self.parameters['stop'].append('### ')
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
self.request_json_body['top_p'] = 0.01
traceback.print_exc()
llm_request = {**self.parameters, 'prompt': self.prompt}
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
model = self.request_json_body.get('model')
if success:
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
else:
return backend_response, backend_response_status_code
def handle_ratelimited(self, do_log: bool = True):
# TODO: return a simulated OpenAI error message
# Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
return 'Ratelimited', 429
model_choices, default_model = get_model_choices()
default_model_info = model_choices[default_model]
w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2
response = jsonify({
"error": {
"message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
"type": "rate_limit_exceeded",
"param": None,
"code": None
}
})
response.headers['x-ratelimit-limit-requests'] = '2'
response.headers['x-ratelimit-remaining-requests'] = '0'
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
if do_log:
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
# TODO: return a simulated OpenAI error message
print('OAI Error:', error_msg)
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",
@ -84,3 +115,51 @@ class OpenAIRequestHandler(RequestHandler):
"code": None
}
}), 400
def build_openai_response(self, prompt, response, model=None):
# Separate the user's prompt from the context
x = prompt.split('### USER:')
if len(x) > 1:
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
prompt_tokens = get_token_count(prompt, self.backend_url)
response_tokens = get_token_count(response, self.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
"object": "chat.completion",
"created": int(time.time()),
"model": running_model if opts.openai_expose_our_model else model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": response,
},
"logprobs": None,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": response_tokens,
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
return response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
self.parameters, parameters_invalid_msg = self.get_parameters()
if not self.parameters:
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
return False, (Response('Invalid request, check your parameters and try again.'), 400)
invalid_oai_err_msg = validate_oai(self.parameters)
if invalid_oai_err_msg:
return False, invalid_oai_err_msg
# self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
# If the parameters were invalid, let the superclass deal with it.
return super().validate_request(prompt, do_log)

View File

@ -1,12 +1,15 @@
import json
import pickle
import time
from typing import Tuple
from uuid import uuid4
import ujson as json
from redis import Redis
from llm_server import opts
from llm_server.routes.cache import redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
def increment_ip_count(client_ip: str, redis_key):
@ -20,24 +23,30 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue:
def __init__(self):
self.redis = Redis(host='localhost', port=6379, db=15)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events')
"""
A queue for a specific backend.
"""
def put(self, item, priority):
event = DataEvent()
def __init__(self, name, db: int = 12):
self.name = name
self.redis = RedisCustom(name, db=db)
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
# TODO: remove this when we're sure nothing strange is happening
assert item is not None
assert priority is not None
assert selected_model is not None
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count:
ip_count = int(ip_count)
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
ip_count = self.get_ip_request_count(item[1])
_, simultaneous_ip = get_token_ratelimit(item[2])
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests queued.')
return None # reject the request
self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
timestamp = time.time()
event = DataEvent()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
return event
def get(self):
@ -45,31 +54,59 @@ class RedisPriorityQueue:
data = self.redis.zpopmin('queue')
if data:
item = json.loads(data[0][0])
client_ip = item[0][1]
self.decrement_ip_count(client_ip, 'queued_ip_count')
return item
time.sleep(0.1) # wait for something to be added to the queue
def increment_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(redis_key, client_ip, 1)
def decrement_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, -1)
if new_count <= 0:
self.redis.hdel(redis_key, client_ip)
def __len__(self):
return self.redis.zcard('queue')
def get_queued_ip_count(self, client_ip: str):
q = self.redis.hget('queued_ip_count', client_ip)
if not q:
return 0
return 0
def get_ip_request_count(self, client_ip: str):
"""
Get the number of requests in the queue from a specific IP.
This is a bit inefficient since we iterate over the entire queue, but it
keeps the queue as the single source of truth instead of tracking a separate hashed
set, which can get confusing.
If we run into slowdowns in the future, we should go back to the hashed set approach.
:param client_ip:
:return:
"""
start_time = time.time()
items = self.redis.zrange('queue', 0, -1)
count = 0
for item in items:
item_data = json.loads(item)
if item_data[0][1] == client_ip:
count += 1
elapsed_time = time.time() - start_time
if elapsed_time > 0.5:
raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
return count
def flush(self):
self.redis.flush()
def items(self):
return self.redis.zrange('queue', 0, -1)
def cleanup(self):
now = time.time()
for item in self.items():
item_data = json.loads(item)
timestamp = item_data[-2]
if now - timestamp > opts.backend_generate_request_timeout:
self.redis.zrem('queue', 0, item)
event_id = item_data[1]
event = DataEvent(event_id)
event.set((False, None, 'closed'))
print('Removed timed-out item from queue:', event_id)
class DataEvent:
def __init__(self, event_id=None):
"""
Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol).
"""
def __init__(self, event_id: str = None):
self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub()
@ -84,15 +121,89 @@ class DataEvent:
return pickle.loads(item['data'])
priority_queue = RedisPriorityQueue()
def update_active_workers(key: str, operation: str):
if operation == 'incr':
redis.incr(f'active_gen_workers:{key}')
elif operation == 'decr':
redis.decr(f'active_gen_workers:{key}')
if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
redis.set(f'active_gen_workers:{key}', 0)
def incr_active_workers():
redis.incr('active_gen_workers')
def incr_active_workers(selected_model: str, backend_url: str):
update_active_workers(selected_model, 'incr')
update_active_workers(backend_url, 'incr')
def decr_active_workers():
redis.decr('active_gen_workers')
new_count = redis.get('active_gen_workers', int, 0)
if new_count < 0:
redis.set('active_gen_workers', 0)
def decr_active_workers(selected_model: str, backend_url: str):
update_active_workers(selected_model, 'decr')
update_active_workers(backend_url, 'decr')
class PriorityQueue:
"""
Helper class to wrangle all the different queues.
"""
def __init__(self, backends: set = None):
"""
Only have to load the backends once.
:param backends:
"""
self.redis = Redis(host='localhost', port=6379, db=9)
if backends:
for item in backends:
self.redis.sadd('backends', item)
def get_backends(self):
return {x.decode('utf-8') for x in self.redis.smembers('backends')}
def get_queued_ip_count(self, client_ip: str):
count = 0
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
count += queue.get_ip_request_count(client_ip)
return count
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model, do_stream)
def activity(self):
lines = []
status_redis = RedisCustom('worker_status')
for worker in status_redis.keys():
lines.append((worker, status_redis.getp(worker)))
return sorted(lines)
def len(self, model_name):
count = 0
backends_with_models = set()
for k in self.get_backends():
info = cluster_config.get_backend(k)
if info.get('model') == model_name:
backends_with_models.add(k)
for backend_url in backends_with_models:
count += len(RedisPriorityQueue(backend_url))
return count
def __len__(self):
count = 0
p = set()
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
p.add((backend_url, len(queue)))
count += len(queue)
return count
def flush(self):
for k in self.redis.keys():
q = json.loads(self.redis.get(k))
q.flush()
self.redis.set(k, json.dumps(q))
def flush_db(self):
self.redis.flushdb()
priority_queue = PriorityQueue()
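Aside (not part of the diff): a sketch of the enqueue/wait flow from a handler's point of view. The backend URL, payload, and model name are placeholders; the item layout and the 3-tuple returned by event.wait() follow the non-streaming path in request_handler.py.

llm_request = {'prompt': 'Hello', 'max_new_tokens': 16}
item = (llm_request, '203.0.113.7', None, {'temperature': 0.7})  # (request, client_ip, token, parameters)

event = priority_queue.put('http://10.0.0.50:7000', item, 9999, 'some-model')
if event is None:
    print('rejected: this IP already has too many requests queued')
else:
    # Blocks until a worker publishes a result; error_msg == 'closed' means cleanup() timed the item out.
    success, response, error_msg = event.wait()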

View File

@ -5,23 +5,22 @@ import flask
from flask import Response, request
from llm_server import opts
from llm_server.database.conn import database
from llm_server.database.database import log_prompt
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
from llm_server.custom_redis import redis
from llm_server.database.database import get_token_ratelimit
from llm_server.database.log_to_db import log_to_db
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.cache import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
# Routes need to validate it, here we just load it
if incoming_json:
@ -34,11 +33,38 @@ class RequestHandler:
self.start_time = time.time()
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
self.backend = get_backend()
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
self.parameters = None
self.used = False
redis.zadd('recent_prompters', {self.client_ip: time.time()})
# This is null by default since most handlers need to transform the prompt in a specific way.
self.prompt = None
self.selected_model = selected_model
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
# Debug stuff
# if not self.cluster_backend_info.get('mode'):
# print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model'):
# print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model_config'):
# print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
self.offline = True
else:
self.offline = False
self.selected_model = self.cluster_backend_info['model']
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
if self.token and not self.token.startswith('SYSTEM__'):
# "recent_prompters" is only used for stats.
redis.zadd('recent_prompters', {self.client_ip: time.time()})
def check_online(self) -> bool:
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
return self.cluster_backend_info['online']
def get_auth_token(self):
if self.request_json_body.get('X-API-KEY'):
@ -49,6 +75,8 @@ class RequestHandler:
return parse_token(self.request.headers['Authorization'])
def get_client_ip(self):
if self.request.headers.get('Llm-Connecting-Ip'):
return self.request.headers['Llm-Connecting-Ip']
if self.request.headers.get('X-Connecting-IP'):
return self.request.headers.get('X-Connecting-IP')
elif self.request.headers.get('Cf-Connecting-Ip'):
@ -58,26 +86,7 @@ class RequestHandler:
else:
return self.request.remote_addr
def get_token_ratelimit(self):
priority = DEFAULT_PRIORITY
simultaneous_ip = opts.simultaneous_requests_per_ip
if self.token:
cursor = database.cursor()
try:
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
result = cursor.fetchone()
if result:
priority, simultaneous_ip = result
if simultaneous_ip is None:
# No ratelimit for this token if null
simultaneous_ip = 999999999
finally:
cursor.close()
return priority, simultaneous_ip
def get_parameters(self):
if self.request_json_body.get('max_tokens'):
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
return parameters, parameters_invalid_msg
@ -119,7 +128,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
return False, backend_response
return True, (None, 0)
@ -131,14 +140,18 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid:
return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
else:
event = None
if not event:
return (False, None, None, 0), self.handle_ratelimited()
# TODO: add wait timeout
success, response, error_msg = event.wait()
if error_msg == 'closed':
return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
end_time = time.time()
elapsed_time = end_time - self.start_time
@ -160,7 +173,17 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_to_db(ip=self.client_ip,
token=self.token,
prompt=prompt,
response=backend_response[0].data.decode('utf-8'),
gen_time=None,
parameters=self.parameters,
headers=dict(self.request.headers),
backend_response_code=response_status_code,
request_url=self.request.url,
backend_url=self.backend_url,
is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -180,7 +203,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -189,22 +212,29 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool:
if self.token_priority == 0:
return False
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
x = redis.hget('processing_ips', self.client_ip)
if x:
processing_ip = int(x)
else:
processing_ip = 0
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
return False
else:
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
return True
else:
return False
def handle_request(self) -> Tuple[flask.Response, int]:
# Must include this in your child.
# if self.used:
# raise Exception('Can only use a RequestHandler object once.')
# assert not self.used
# if self.offline:
# msg = f'{self.selected_model} is not a valid model choice.'
# print(msg)
# return format_sillytavern_err(msg)
raise NotImplementedError
def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
@ -214,11 +244,11 @@ class RequestHandler:
raise NotImplementedError
def get_backend():
if opts.mode == 'oobabooga':
return OobaboogaBackend()
elif opts.mode == 'vllm':
return VLLMBackend()
def get_backend_handler(mode, backend_url: str):
if mode == 'oobabooga':
return OobaboogaBackend(backend_url)
elif mode == 'vllm':
return VLLMBackend(backend_url)
else:
raise Exception

View File

@ -1,3 +1,3 @@
def handle_server_error(e):
print(e)
return {'error': True}, 500
print('Internal Error:', e)
return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@ -1,33 +1,11 @@
from datetime import datetime
from llm_server.routes.cache import redis
# proompters_5_min = 0
# concurrent_semaphore = Semaphore(concurrent_gens)
from llm_server.custom_redis import redis
from llm_server.helpers import round_up_base
server_start_time = datetime.now()
# TODO: do I need this?
# def elapsed_times_cleanup():
# global wait_in_queue_elapsed
# while True:
# current_time = time.time()
# with wait_in_queue_elapsed_lock:
# global wait_in_queue_elapsed
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
# time.sleep(1)
def calculate_avg_gen_time():
# Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time')
if average_generation_time is None:
return 0
else:
return float(average_generation_time)
def get_total_proompts():
count = redis.get('proompts')
if count is None:
@ -37,10 +15,27 @@ def get_total_proompts():
return count
def get_active_gen_workers():
active_gen_workers = redis.get('active_gen_workers')
if active_gen_workers is None:
count = 0
def get_active_gen_workers_model(selected_model: str = None):
return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0)
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
count = int(active_gen_workers)
return count
# No queue
return gen_time_calc
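Aside (not part of the diff): two worked calls, assuming an average generation time of 8 seconds (all values are illustrative).

calculate_wait_time(8, proompters_in_queue=5, concurrent_gens=3, active_gen_workers=1)
# -> 0: a worker slot is still free, so a new request should start immediately.

calculate_wait_time(8, proompters_in_queue=2, concurrent_gens=3, active_gen_workers=3)
# -> 16: every slot is busy; the short queue waits one generation (8s) plus the
#        generation that is currently running (8s).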

View File

@ -3,18 +3,18 @@ import traceback
from flask import jsonify, request
from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
@bp.route('/generate', methods=['POST'])
def generate():
@bp.route('/v1/generate', methods=['POST'])
@bp.route('/<model_name>/v1/generate', methods=['POST'])
def generate(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
handler = OobaRequestHandler(request)
handler = OobaRequestHandler(request, selected_model=model_name)
try:
return handler.handle_request()
except Exception:

View File

@ -2,83 +2,30 @@ import time
from datetime import datetime
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states
from llm_server.routes.cache import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
from llm_server.helpers import deep_sort
from llm_server.routes.stats import get_total_proompts, server_start_time
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
# No queue
return gen_time_calc
# TODO: have routes/__init__.py point to the latest API version generate_stats()
def generate_stats(regen: bool = False):
if not regen:
c = redis.get('proxy_stats', dict)
c = redis.getp('proxy_stats')
if c:
return c
model_name, error = get_running_model() # will return False when the fetch fails
if isinstance(model_name, bool):
online = False
else:
online = True
redis.set('running_model', model_name)
model_choices, default_model = get_model_choices(regen=True)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
# if len(t) == 0:
# estimated_wait = 0
# else:
# waits = [elapsed for end, elapsed in t]
# estimated_wait = int(sum(waits) / len(waits))
active_gen_workers = get_active_gen_workers()
proompters_in_queue = len(priority_queue)
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
if opts.netdata_root:
netdata_stats = {}
power_states = get_power_states()
for gpu, power_state in power_states.items():
netdata_stats[gpu] = {
'power_state': power_state,
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
}
else:
netdata_stats = {}
base_client_api = redis.get('base_client_api', str)
base_client_api = redis.get('base_client_api', dtype=str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
output = {
'models': {
'choices': model_choices,
'default': default_model,
},
'stats': {
'proompters': {
'5_min': proompters_5_min,
@ -86,39 +33,49 @@ def generate_stats(regen: bool = False):
},
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(average_generation_time),
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
},
'online': online,
'endpoints': {
'blocking': f'https://{base_client_api}',
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
},
'queue': {
'processing': active_gen_workers,
'queued': proompters_in_queue,
'estimated_wait_sec': int(estimated_wait_sec),
},
'timestamp': int(time.time()),
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'concurrent': opts.concurrent_gens,
'model': opts.manual_model_name if opts.manual_model_name else model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
'api_mode': opts.frontend_api_mode
},
'keys': {
'openaiKeys': '',
'anthropicKeys': '',
},
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
'nvidia': netdata_stats
'backends': {},
'online': len(model_choices) > 0
}
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
if opts.show_backends:
for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']:
continue
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
output['backends'][backend_info['hash']] = {
'uptime': backend_uptime,
'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),
'model': backend_info['model'],
'mode': backend_info['mode'],
'nvidia': backend_info['nvidia'],
'priority': backend_info['priority'],
}
result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then.
if base_client_api:
redis.set_dict('proxy_stats', result) # Cache with no expiry
redis.setp('proxy_stats', result)
return result

View File

@ -1,186 +1,200 @@
import json
import threading
import time
import traceback
from typing import Union
import ujson
from flask import request
from redis import Redis
from ..cache import redis
from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ..queue import priority_queue
from ... import opts
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...sock import sock
# TODO: have workers process streaming requests
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
# Stacking the @sock.route() decorators creates a TypeError on the /v1/stream endpoint.
# We solve this by splitting the routes
@sock.route('/api/v1/stream')
def stream(ws):
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_in_bg(quitting_err_msg, is_error=True)
@bp.route('/v1/stream')
@bp.route('/<model_name>/v1/stream')
def stream(model_name=None):
return 'This is a websocket endpoint.', 400
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
def background_task_exception():
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
@sock.route('/v1/stream', bp=bp)
def stream_without_model(ws):
do_stream(ws, model_name=None)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
if not opts.enable_streaming:
return 'Streaming is disabled', 401
@sock.route('/<model_name>/v1/stream', bp=bp)
def stream_with_model(ws, model_name=None):
do_stream(ws, model_name)
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
if opts.mode != 'vllm':
# TODO: implement other backends
raise NotImplementedError
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
def do_stream(ws, model_name):
event_id = None
try:
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
'event': 'text_stream',
'message_num': 0,
'text': quitting_err_msg
}))
ws.send(json.dumps({
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=quitting_err_msg,
gen_time=None,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url,
response_tokens=None,
is_error=True
)
handler = OobaRequestHandler(request, request_json_body)
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
if not opts.enable_streaming:
return 'Streaming disabled', 403
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
while ws.connected:
message = ws.receive()
request_valid_json, request_json_body = validate_json(message)
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
if err_msg:
send_err_and_quit(err_msg)
return
# We have to do auth ourselves since the details are sent in the message.
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
llm_request = {
**handler.parameters,
'prompt': input_prompt,
'stream': True,
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
if handler.offline:
msg = f'{handler.selected_model} is not a valid model choice.'
print(msg)
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'message_num': 0,
'text': msg
}))
return
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
err_msg = None
if handler.is_client_ratelimited():
r, _ = handler.handle_ratelimited(do_log=False)
err_msg = r.json['results'][0]['text']
else:
# Be extra careful when getting attributes from the response object
try:
response_status_code = response.status_code
except:
response_status_code = 0
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
if not request_valid:
err_msg = invalid_response[0].json['results'][0]['text']
if err_msg:
send_err_and_quit(err_msg)
return
partial_response = b''
handler.parameters, _ = handler.get_parameters()
handler.prompt = input_prompt
handler.request_json_body = {
'prompt': handler.prompt,
**handler.parameters
}
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
# The client has closed the stream.
if request:
request.close()
ws.close()
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
return
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
r = handler.handle_ratelimited()
send_err_and_quit(r[0].data)
return
event_id = event.event_id
_, stream_name, error_msg = event.wait()
if error_msg:
print('Stream failed to start streaming:', error_msg)
ws.close(reason=1014, message='Request Timeout')
return
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0' # The ID of the last entry we read.
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data:
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
return
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
data = ujson.loads(item[b'data'])
if data['error']:
print(data['error'])
send_err_and_quit('Encountered exception while streaming.')
return
elif data['new']:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': data['new']
}))
message_num += 1
partial_response = b'' # Reset the partial response
# If there is no more data, break the loop
if not chunk:
break
end_time = time.time()
elapsed_time = end_time - start_time
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
except:
traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': generated_text
}))
if request:
request.close()
ws.close()
log_in_bg(generated_text, is_error=True, status_code=response_status_code)
return
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers()
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
except:
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
ws.close() # this is important if we encountered an error and exited early.
generated_text = generated_text + data['new']
elif data['completed']:
return
except:
send_err_and_quit('Encountered exception while streaming.')
traceback.print_exc()
finally:
try:
ws.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
except:
# The client closed the stream.
pass
if stream_name:
stream_redis.delete(stream_name)
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
response=generated_text,
gen_time=elapsed_time,
parameters=handler.parameters,
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.backend_url
)
finally:
if event_id:
redis.publish(f'notifications:{event_id}', 'canceled')
try:
# Must close the connection or greenlets will complain.
ws.close()
except:
pass

View File

@ -2,22 +2,16 @@ import time
from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp
from ..auth import requires_auth
from ..cache import flask_cache
from ... import opts
from ...llm.info import get_running_model
from ...cluster.backend import get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
# @bp.route('/info', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_info():
# # requests.get()
# return 'yes'
@bp.route('/model', methods=['GET'])
def get_model():
@bp.route('/v1/model', methods=['GET'])
@bp.route('/<model_name>/v1/model', methods=['GET'])
def get_model(model_name=None):
# We will manage caching ourselves since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -26,24 +20,21 @@ def get_model():
if cached_response:
return cached_response
model_name, error = get_running_model()
if not model_name:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not is_valid_model(model_name):
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
'code': 400,
'msg': 'Model does not exist.',
}), 400
else:
num_backends = len(get_backends_from_model(model_name))
response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'model_backend_count': num_backends,
'timestamp': int(time.time())
}), 200
flask_cache.set(cache_key, response, timeout=60)
return response
@bp.route('/backend', methods=['GET'])
@requires_auth
def get_backend():
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200

View File

@ -1,8 +1,10 @@
from flask import jsonify
from llm_server.custom_redis import flask_cache
from . import bp
from .generate_stats import generate_stats
from ..cache import flask_cache
from ..auth import requires_auth
from ...cluster.cluster_config import cluster_config, get_backends
from ...helpers import jsonify_pretty
@ -10,3 +12,14 @@ from ...helpers import jsonify_pretty
@flask_cache.cached(timeout=5, query_string=True)
def get_stats():
return jsonify_pretty(generate_stats())
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
online, offline = get_backends()
result = {}
for i in online + offline:
info = cluster_config.get_backend(i)
result[info['hash']] = info
return jsonify(result), 200

View File

@ -3,6 +3,6 @@ from flask_sock import Sock
sock = Sock()
def init_socketio(app):
def init_wssocket(app):
global sock
sock.init_app(app)

View File

@ -1,35 +0,0 @@
from threading import Thread
from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')

View File

@ -1,51 +0,0 @@
import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.routes.cache import redis
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
need_to_wait()
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers()
if not request_json_body:
# This was a dummy request from the websocket handler.
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
continue
try:
success, response, error_msg = generator(request_json_body)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers()
def start_workers(num_workers: int):
i = 0
for _ in range(num_workers):
t = threading.Thread(target=worker)
t.daemon = True
t.start()
i += 1
print(f'Started {i} inference workers.')
def need_to_wait():
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', int, 0)
s = time.time()
while active_workers >= opts.concurrent_gens:
time.sleep(0.01)
e = time.time()
if e - s > 0.5:
print(f'Worker was delayed {e - s} seconds.')

View File

@ -0,0 +1,32 @@
import time
from redis import Redis
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
# NOT NEEDED
def cleaner():
r = Redis(db=8)
stream_info = {}
while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
processed_streams = []
for stream in all_streams:
stream = stream.decode()
current_size = r.xlen(stream)
# If the stream is new or its size has changed, update the size and time in the dictionary
if stream not in stream_info or current_size != stream_info[stream]['size']:
stream_info[stream] = {'size': current_size, 'time': time.time()}
processed_streams.append(stream)
else:
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]
time.sleep(60)

View File

@ -0,0 +1,160 @@
import json
import threading
import time
import traceback
from uuid import uuid4
import ujson
from redis import Redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.logging import create_logger
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
stream_redis = Redis(db=8)
STREAM_NAME_PREFIX = 'stream'
def check_cancellation(event, event_id):
"""
This thread checks the pub/sub channel in the background so the main process
isn't bogged down with Redis calls. Otherwise, the main process slows down to 1 token/sec.
:param event:
:param event_id:
:return:
"""
pubsub = redis.pubsub()
pubsub.subscribe(f'notifications:{event_id}')
while not event.is_set():
message = pubsub.get_message()
if message and message['data'] == b'canceled':
event.set()
time.sleep(0.5) # check every half second
def get_stream_name(name: str):
return f'{STREAM_NAME_PREFIX}:{name}'
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
logger = create_logger('inferencer')
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(stream_name) # be extra sure; stream_name is already prefixed at this point
event = threading.Event()
threading.Thread(target=check_cancellation, args=(event, event_id)).start()
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
# If there is no more data, break the loop
if not chunk:
break
if event.is_set():
logger.debug('Client canceled generation')
response.close()
return
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
except AttributeError as e:
if str(e) == "'bool' object has no attribute 'iter_content'":
# We don't care about these errors.
logger.debug('failed to stream from backend - no response')
else:
raise
except Exception as e:
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
raise # We won't handle the exception here.
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
event.set() # stop the cancellation checking thread
#
def worker(backend_url):
logger = create_logger('inferencer')
status_redis = RedisCustom('worker_status')
worker_id = str(uuid4())
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
status_redis.setp(str(worker_id), 'waiting...')
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
event = DataEvent(event_id)
try:
backend_info = cluster_config.get_backend(backend_url)
except:
# This is not a critical error because it usually means that the backend is
# offline and this backend is in a state of transition from online to offline.
logger.debug(f'got an exception while getting info for backend {backend_url}: {traceback.format_exc()}')
event.set((False, None, 'exception'))
continue
if not backend_info['online']:
event.set((False, None, 'canceled'))
continue
if not selected_model:
selected_model = backend_info['model']
logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}. Streaming: {do_stream}")
try:
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)
if do_stream:
status_redis.setp(str(worker_id), ('streaming', client_ip))
# Return the name of the stream that the slave should connect to.
event.set((True, get_stream_name(worker_id), None))
msg_to_backend = {
**parameters,
'prompt': request_json_body['prompt'],
'stream': True,
}
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
else:
# Normal inference (not streaming).
status_redis.setp(str(worker_id), ('generating', client_ip))
success, response, error_msg = generator(request_json_body, backend_url)
event.set((success, response, error_msg))
except:
logger.error(traceback.format_exc())
event.set((False, None, 'exception'))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None)
def start_workers(cluster: dict):
logger = create_logger('inferencer')
i = 0
for item in cluster:
for _ in range(item['concurrent_gens']):
t = threading.Thread(target=worker, args=(item['backend_url'],))
t.daemon = True
t.start()
i += 1
logger.info(f'Started {i} inference workers.')

View File

@ -0,0 +1,31 @@
import pickle
import traceback
import redis
from llm_server.database.database import do_db_log
def db_logger():
"""
We don't want the logging operation to be blocking, so we will use a background worker
to do the logging.
:return:
"""
r = redis.Redis(host='localhost', port=6379, db=3)
p = r.pubsub()
p.subscribe('database-logger')
for message in p.listen():
try:
if message['type'] == 'message':
data = pickle.loads(message['data'])
function_name = data['function']
args = data['args']
kwargs = data['kwargs']
if function_name == 'log_prompt':
do_db_log(*args, **kwargs)
except:
traceback.print_exc()
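The publishing side of this channel is not shown in this hunk; below is a minimal sketch of what a producer might look like, assuming it pickles the call in the same `{'function', 'args', 'kwargs'}` shape the worker expects (the function name `log_prompt_in_background` is hypothetical):
```python
import pickle

import redis

r = redis.Redis(host='localhost', port=6379, db=3)

def log_prompt_in_background(*args, **kwargs):
    # Hypothetical producer: serialize the call and hand it off to the db_logger worker.
    r.publish('database-logger', pickle.dumps({
        'function': 'log_prompt',
        'args': args,
        'kwargs': kwargs,
    }))
```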

View File

@ -1,56 +0,0 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.routes.cache import redis
def main_background_thread():
redis.set('average_generation_elapsed_sec', 0)
redis.set('estimated_avg_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
while True:
# TODO: unify this
if opts.mode == 'oobabooga':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
elif opts.mode == 'vllm':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
else:
raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
redis.set('estimated_avg_tps', estimated_avg_tps)
time.sleep(60)

View File

@ -0,0 +1,57 @@
import time
import requests
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config, get_backends
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
def main_background_thread():
while True:
online, offline = get_backends()
for backend_url in online:
backend_info = cluster_config.get_backend(backend_url)
backend_mode = backend_info['mode']
backend_info = get_info(backend_url, backend_mode)
running_model = backend_info.get('model')
if not running_model:
continue
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
if average_generation_elapsed_sec: # returns None on exception
cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
if average_output_tokens:
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
if average_generation_elapsed_sec and average_output_tokens:
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
if opts.background_homepage_cacher:
try:
base_client_api = redis.get('base_client_api', dtype=str)
r = requests.get('https://' + base_client_api, timeout=5)
except Exception as e:
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
backends = priority_queue.get_backends()
for backend_url in backends:
queue = RedisPriorityQueue(backend_url)
queue.cleanup()
time.sleep(30)
def calc_stats_for_backend(backend_url, running_model, backend_mode):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps

View File

@ -1,10 +1,13 @@
import json
import threading
import time
import traceback
import redis as redis_redis
from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger
redis_moderation = redis_redis.Redis()
@ -16,36 +19,43 @@ def start_moderation_workers(num_workers):
t.daemon = True
t.start()
i += 1
print(f'Started {i} moderation workers.')
def moderation_worker():
while True:
result = redis_moderation.blpop('queue:msgs_to_check')
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
print(result)
traceback.print_exc()
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
# TODO: don't use UUID tags to identify items. Use native redis
def get_results(tag, num_tasks):
tag = str(tag) # Required for comparison with Redis results.
tag = str(tag) # Cast a UUID4 to a string.
flagged_categories = set()
num_results = 0
start_time = time.time()
while num_results < num_tasks:
result = redis_moderation.blpop('queue:flagged_categories')
result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
if result is None:
break # Timeout occurred, break the loop.
result_tag, categories = json.loads(result[1])
if result_tag == tag:
if categories:
for item in categories:
flagged_categories.add(item)
num_results += 1
if time.time() - start_time > opts.openai_moderation_timeout:
logger.warning('Timed out waiting for result from moderator')
break
return list(flagged_categories)
def moderation_worker():
logger = create_logger('moderator')
while True:
result = redis_moderation.blpop(['queue:msgs_to_check'])
try:
msg, tag = json.loads(result[1])
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
logger.error(traceback.format_exc())
continue
def add_moderation_task(msg, tag):
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
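A minimal usage sketch of the tag-based flow above; the caller shown here is an assumption for illustration, not code from this diff:
```python
from uuid import uuid4

# Hypothetical caller: fan each message out to the moderation workers, then collect the union of flags.
tag = uuid4()
messages = ['first user message', 'second user message']
for msg in messages:
    add_moderation_task(msg, tag)
flagged_categories = get_results(tag, len(messages))
if flagged_categories:
    print('Flagged for:', ', '.join(flagged_categories))
```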

View File

@ -1,25 +1,34 @@
import logging
import time
import traceback
from llm_server.routes.cache import redis
from llm_server.cluster.backend import get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.logging import create_logger
from llm_server.routes.queue import priority_queue
logger = logging.getLogger('console_printer')
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
def console_printer():
logger = create_logger('console_printer')
time.sleep(3)
while True:
processing = redis.hkeys('processing_ips')
processing_count = 0
for ip in processing:
processing_count += int(redis.hget('processing_ips', ip))
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
try:
processing = redis.keys('active_gen_workers:http*') # backends always start with http
processing_count = 0
if len(processing):
for k in processing:
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
activity = priority_queue.activity()
# Calculate the queue size the same way it's done on the stats.
queue_size = 0
running_models = get_running_models()
for model in running_models:
queue_size += priority_queue.len(model)
# Active Workers and Processing should read the same. If not, that's an issue.
logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
except:
logger.error(traceback.format_exc())
time.sleep(10)

View File

@ -1,6 +1,6 @@
import time
from llm_server.routes.cache import redis
from llm_server.custom_redis import redis
def recent_prompters_thread():

View File

@ -0,0 +1,58 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.cluster.worker import cluster_worker
from llm_server.logging import create_logger
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.logger import db_logger
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers
from llm_server.workers.printer import console_printer
from llm_server.workers.recenter import recent_prompters_thread
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)
def start_background():
logger = create_logger('threader')
start_workers(opts.cluster)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
logger.info('Started the main background thread.')
num_moderators = opts.cluster_workers * 3
start_moderation_workers(num_moderators)
logger.info(f'Started {num_moderators} moderation workers.')
t = Thread(target=cache_stats)
t.daemon = True
t.start()
logger.info('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
logger.info('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
logger.info('Started the console printer.')
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
logger.info('Started the cluster worker.')
t = Thread(target=db_logger)
t.daemon = True
t.start()
logger.info('Started background logger.')

View File

@ -1,9 +0,0 @@
import time
from llm_server.routes.v1.generate_stats import generate_stats
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)

103
other/gradio/gradio_chat.py Normal file
View File

@ -0,0 +1,103 @@
import os
import sys
import time
import traceback
import warnings
from threading import Thread
import gradio as gr
import openai
import requests
warnings.filterwarnings("ignore")
API_BASE = os.getenv('API_BASE')
if not API_BASE:
print('Must set the secret variable API_BASE to your https://your-site/api')
sys.exit(1)
API_BASE = API_BASE.strip('/')
APP_TITLE = os.getenv('APP_TITLE')
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
TRACKING_CODE = os.getenv('TRACKING_CODE')
def background():
while True:
previous = openai.api_base
try:
r = requests.get(API_BASE + '/stats').json()
if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
else:
openai.api_base = API_BASE + '/openai/v1'
except:
traceback.print_exc()
openai.api_base = API_BASE + '/openai/v1'
if openai.api_base != previous:
print('Set primary model to', openai.api_base)
time.sleep(10)
if PRIMARY_MODEL_CHOICE:
t = Thread(target=background)
t.daemon = True
t.start()
print('Started the background thread.')
# A system prompt can be injected into the very first spot in the context.
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
# the content in CONTEXT_TRIGGER_INJECTION will be injected.
# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples.
CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
openai.api_key = 'null'
openai.api_base = API_BASE + '/openai/v1'
def stream_response(prompt, history):
messages = []
do_injection = False
for human, assistant in history:
messages.append({'role': 'user', 'content': str(human)})
messages.append({'role': 'assistant', 'content': str(assistant)})
if CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in human:
do_injection = True
messages.append({'role': 'user', 'content': prompt})
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
try:
response = openai.ChatCompletion.create(
model='0',
messages=messages,
temperature=0,
max_tokens=300,
stream=True,
headers={'LLM-Source': 'huggingface-demo'}
)
except Exception:
raise gr.Error("Failed to reach inference endpoint.")
message = ''
for chunk in response:
if len(chunk['choices'][0]['delta']) != 0:
message += chunk['choices'][0]['delta']['content']
yield message
examples = ["hello"]
if CONTEXT_TRIGGER_PHRASE:
examples.insert(0, CONTEXT_TRIGGER_PHRASE)
with gr.Blocks(analytics_enabled=False) as demo:
gr.ChatInterface(stream_response, examples=examples, title=APP_TITLE, analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}')
if TRACKING_CODE:
print('Inserting tracking code')
gr.HTML(TRACKING_CODE)
demo.queue(concurrency_count=1, api_open=False).launch(show_api=False)

View File

@ -0,0 +1,3 @@
gradio
openai
requests

View File

@ -1,3 +1,8 @@
"""
This file is used to run certain tasks when the HTTP server starts.
It's located here so it doesn't get imported with daemon.py
"""
try:
import gevent.monkey

38
other/nginx-site.conf Normal file
View File

@ -0,0 +1,38 @@
server
{
listen 443 ssl http2 default_server;
server_name _;
proxy_set_header Host $host;
proxy_set_header Connection $http_connection;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Scheme $scheme;
location ~* ^/api/(.*?|v1|openai)/(v1|(generate|stream)|(chat/completions|completions))$
{
# Route to inference endpoints
proxy_pass http://127.0.0.1:5000;
# Required for streaming (both websockets and SSE).
proxy_buffering off;
proxy_cache off;
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
# Set long timeouts for inference operations.
# Cloudflare has a timeout of 100 seconds.
proxy_read_timeout 120;
proxy_connect_timeout 120;
proxy_send_timeout 120;
}
location /
{
proxy_pass http://127.0.0.1:5000;
}
ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt;
ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key;
include /etc/nginx/snippets/ssl-params.conf;
}

11
other/tests/config.sh Normal file
View File

@ -0,0 +1,11 @@
HOST="proxy.chub-archive.evulid.cc"
AUTH_KEY="TEST_1df979f0-6df1-41bd-814a-e99b1680e727"
PROXY_SERVERS=(
"http://172.0.4.7:3128"
"http://172.0.4.8:3128"
"http://172.0.4.10:3128"
"http://172.0.4.12:3128"
"http://172.0.4.13:3128"
)

58
other/tests/generate.sh Executable file
View File

@ -0,0 +1,58 @@
#!/bin/bash
SLEEP_TIME=2
while getopts p:t: flag; do
case "${flag}" in
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ -n "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"prompt": "Write a 300 word story about an apple tree.",
"temperature": 1,
"max_new_tokens": 100,
"top_p": 1.0,
"top_k": -1,
"use_beam_search": false,
"stop": ["TEST"],
"ignore_eos": false,
"presence_penalty": 0.0,
"frequency_penalty": 0.0,
"length_penalty": 1.0,
"early_stopping": false
}
EOF
)
curl "https://$HOST/api/v1/generate" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

View File

@ -0,0 +1,52 @@
#!/bin/bash
DO_STREAM=false
SLEEP_TIME=2
while getopts p:t:s flag; do
case "${flag}" in
s) DO_STREAM=true ;;
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ ! -z "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"model": "gpt-4",
"messages": [{"role": "user", "content": "Write a 300 word story about an apple tree."}],
"max_tokens": 100,
"stream": $DO_STREAM
}
EOF
)
curl "https://$HOST/api/openai/v1/chat/completions" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

52
other/tests/oai-completion.sh Executable file
View File

@ -0,0 +1,52 @@
#!/bin/bash
DO_STREAM=false
SLEEP_TIME=2
while getopts p:t:s flag; do
case "${flag}" in
s) DO_STREAM=true ;;
p) PROXY_CHOICE=${OPTARG} ;;
t) SLEEP_TIME=${OPTARG} ;;
*) ;;
esac
done
SOURCE=${BASH_SOURCE[0]}
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
SOURCE=$(readlink "$SOURCE")
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
source "$DIR/config.sh"
if [ ! -z "$PROXY_CHOICE" ]; then
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
echo "Using $our_proxy_server"
else
our_proxy_server=""
fi
while true; do
echo "--> START <--"
DATA=$(
cat <<EOF
{
"model": "gpt-4",
"prompt": "Write a 300 word story about an apple tree.",
"max_tokens": 100,
"stream": $DO_STREAM
}
EOF
)
curl "https://$HOST/api/openai/v1/completions" -m 100 -x "$our_proxy_server" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $AUTH_KEY" \
-d "$DATA"
echo -e "--> DONE <--\n"
sleep $SLEEP_TIME
done

64
other/tests/start-bulk.sh Executable file
View File

@ -0,0 +1,64 @@
#!/bin/bash
# Function to display help message
function display_help {
echo "Usage: $0 -n num_windows -c command"
echo
echo " -n, --number Number of windows to create"
echo " -c, --command Command to run in each window"
echo
exit 1
}
# Parse command line arguments
while getopts "n:c:h" opt; do
case ${opt} in
n)
num_windows=${OPTARG}
;;
c)
command=${OPTARG}
;;
h)
display_help
;;
\?)
echo "Invalid option: -$OPTARG" 1>&2
display_help
;;
:)
echo "Option -$OPTARG requires an argument." 1>&2
display_help
;;
esac
done
# Check if number of windows and command are provided
if [ -z "$num_windows" ] || [ -z "$command" ]; then
echo "Both number of windows and command are required."
display_help
fi
# Calculate rows and columns
rows=$(echo "sqrt($num_windows)" | bc)
columns=$(echo "($num_windows + $rows - 1) / $rows" | bc)
# Create a new tmux session
tmux new-session -d -s llm_tester "$command -p 0"
# Create the remaining windows
for ((i = 1; i < $num_windows; i++)); do
if ((i % $columns == 0)); then
tmux select-layout -t llm_tester:0 tiled
tmux select-pane -t 0
tmux split-window -t llm_tester:0 -v "$command -p $i"
else
tmux split-window -t llm_tester:0 -h "$command -p $i"
fi
done
# Balance the windows
tmux select-layout -t llm_tester:0 tiled
# Attach to the session
tmux attach-session -t llm_tester

62
other/ooba-test-streaming.py → other/tests/stream.py Normal file → Executable file
View File

@ -1,37 +1,50 @@
import asyncio
import json
import sys
import os
import time
from pathlib import Path
try:
import websockets
except ImportError:
print("Websockets package not found. Make sure it's installed.")
# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5000'
URI = f'ws://{HOST}/api/v1/stream'
script_path = os.path.dirname(os.path.realpath(__file__))
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
def parse_bash_config(file_path):
config = {}
with open(file_path, 'r') as f:
for line in f:
if line.startswith('#') or '=' not in line:
continue
key, value = line.strip().split('=', 1)
if value.startswith('"') and value.endswith('"'):
value = value[1:-1]
elif value.startswith('(') and value.endswith(')'):
value = value[1:-1].split()
value = [v.strip('"') for v in value]
config[key] = value
return config
config = parse_bash_config(Path(script_path, 'config.sh'))
async def run(context):
# Note: the selected defaults change from time to time.
request = {
'prompt': context,
'max_new_tokens': 250,
'auto_max_new_tokens': False,
'max_tokens_second': 0,
# Generation params. If 'preset' is set to something other than 'None', the values
# in presets/preset-name.yaml are used instead of the individual numbers.
'preset': 'None',
'do_sample': True,
'temperature': 0.7,
'top_p': 0.1,
'typical_p': 1,
'epsilon_cutoff': 0, # In units of 1e-4
'eta_cutoff': 0, # In units of 1e-4
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'tfs': 1,
'top_a': 0,
'repetition_penalty': 1.18,
@ -48,7 +61,6 @@ async def run(context):
'mirostat_eta': 0.1,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
'truncation_length': 2048,
@ -58,7 +70,7 @@ async def run(context):
'stopping_strings': []
}
async with websockets.connect(URI, ping_interval=None) as websocket:
async with websockets.connect(f'wss://{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket:
await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply
@ -67,20 +79,28 @@ async def run(context):
incoming_data = await websocket.recv()
incoming_data = json.loads(incoming_data)
print(incoming_data)
match incoming_data['event']:
case 'text_stream':
yield incoming_data['text']
# case 'text_stream':
# yield incoming_data['text']
case 'stream_end':
return
async def print_response_stream(prompt):
async for response in run(prompt):
print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
print('\n\nfinished')
try:
async for response in run(prompt):
print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
except Exception as e:
print(e)
if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)"
asyncio.run(print_response_stream(prompt))
prompt = "Write a 300 word story about an apple tree.\n\n"
while True:
print('--> START <--')
asyncio.run(print_response_stream(prompt))
print('--> DONE <--')
time.sleep(2)

View File

@ -1,15 +0,0 @@
**A Docker container for running VLLM on Paperspace Gradient notebooks.**
1. Run `jupyter server --generate-config` and `jupyter server password` on your local machine, then copy Jupyter's config directory to `./jupyter`
2. Place your Rathole client config at `./rathole-client.toml`
3. `docker build . -t "paperspace-vllm"`
To test on your local machine, run this command:
```bash
docker run --shm-size 14g --gpus all \
-v /storage/models/awq/MythoMax-L2-13B-AWQ:/models/MythoMax-L2-13B-AWQ \
-p 7000:7000 -p 8888:8888 \
-e API_SERVER_ARGS="--model /models/MythoMax-L2-13B-AWQ --quantization awq --max-num-batched-tokens 99999 --gpu-memory-utilization 1" \
vllm-cloud
```

View File

@ -1,87 +1,50 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build
RUN apt-get update && \
apt-get install -y git python3-pip python3-venv wget unzip && \
rm -rf /var/lib/apt/lists/*
RUN pip3 install --upgrade pip setuptools wheel
RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
WORKDIR /local-llm-server
RUN python3 -m venv /venv
RUN /venv/bin/pip install git+https://github.com/vllm-project/vllm
RUN python3 -m venv /jupyterlab
RUN /jupyterlab/bin/pip install jupyterlab
RUN /jupyterlab/bin/jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
RUN mkdir -p /app
RUN wget https://github.com/rapiz1/rathole/releases/download/v0.4.8/rathole-x86_64-unknown-linux-gnu.zip -O /tmp/rathole.zip
RUN unzip -j /tmp/rathole.zip -d /tmp
RUN rm /tmp/rathole.zip
RUN cp /tmp/rathole /app
# The local local-llm-server repo may be cached, so we will fetch and reset to the remote every time.
# Also, make sure there weren't any pip deps added.
ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache
RUN git fetch; git reset --hard origin/master
RUN /venv/bin/pip install -r requirements.txt
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as runtime
RUN apt-get update && apt-get install -y supervisor && rm -rf /var/lib/apt/lists/*
FROM cyberes/vllm-paperspace-base as runtime
RUN useradd -ms /bin/bash apiserver
RUN usermod -s /bin/bash root
# Required packages
RUN apt-get update && \
apt-get install -y python3 python3-pip wget aria2 git-lfs git openssh-server openssh-client nano tmux file && \
apt-get install -y python3 python3-pip supervisor && \
rm -rf /var/lib/apt/lists/*
RUN pip3 install --upgrade pip setuptools wheel
# Useful Python packages
RUN pip3 install glances
# Useful tools
RUN apt-get update && \
apt-get install -y wget aria2 git-lfs git openssh-server openssh-client nano tmux file && \
rm -rf /var/lib/apt/lists/*
RUN pip3 install --upgrade pip setuptools wheel
RUN pip3 install glances
# Update the git repo
RUN cd /local-llm-server && git reset --hard && git pull
# Enable root SSH login
RUN sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config
# Disable password SSH login
RUN sed -i 's/#PasswordAuthentication yes/PasswordAuthentication no/' /etc/ssh/sshd_config
# Create the necessary directory for SSH
# Create the necessary directory for sshd
RUN mkdir /var/run/sshd
ADD "https://www.random.org/cgi-bin/randbyte?nbytes=10&format=h" skipcache
COPY --from=build /local-llm-server /local-llm-server
COPY --from=build /venv /venv
COPY --from=build /app /app
COPY --from=build /jupyterlab /jupyterlab
RUN cp /local-llm-server/other/vllm/Docker/supervisord.conf /etc/supervisor/conf.d/supervisord.conf
RUN cp /local-llm-server/other/vllm/Docker/start-vllm.sh /app/start-vllm.sh
RUN cp /local-llm-server/other/vllm/Docker/start-container.sh /app/start.sh
# Copy your secrets in
# COPY ./jupyter /app/jupyter
COPY supervisord.conf /etc/supervisor/supervisord.conf
COPY start-vllm.sh /app/start-vllm.sh
COPY init-container.sh /app/init.sh
COPY start-container.sh /app/start.sh
RUN mkdir -p /var/log/app/
RUN chown -R apiserver:apiserver /local-llm-server && \
chown -R apiserver:apiserver /app && \
chown -R apiserver:apiserver /var/log/app/
RUN git config --global --add safe.directory /local-llm-server
RUN chmod +x /app/init.sh
RUN chmod +x /app/start.sh
ENV SHELL="/bin/bash"
# SSH
EXPOSE 22
# VLLM
EXPOSE 7000
# Jupyter
# Expose Jupyter. We don't need to expose VLLM or SSH since rathole will tunnel those.
EXPOSE 8888
CMD /app/start.sh

View File

@ -0,0 +1,43 @@
# This container builds and assembles the Python parts of the Docker container.
# It is used as the base for the resulting container, which avoids having to re-push
# the large PyTorch parts every time the application is rebuilt.
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 as build
RUN apt-get update && \
apt-get install -y git python3-pip python3-venv wget unzip && \
rm -rf /var/lib/apt/lists/*
RUN pip install --upgrade pip setuptools wheel
RUN git clone https://git.evulid.cc/cyberes/local-llm-server.git /local-llm-server
RUN python3 -m venv /jupyterlab
RUN /jupyterlab/bin/pip install jupyterlab
RUN /jupyterlab/bin/jupyter labextension disable "@jupyterlab/apputils-extension:announcements"
RUN mkdir -p /app
RUN wget https://github.com/rapiz1/rathole/releases/download/v0.4.8/rathole-x86_64-unknown-linux-gnu.zip -O /tmp/rathole.zip
RUN unzip -j /tmp/rathole.zip -d /tmp
RUN rm /tmp/rathole.zip
RUN cp /tmp/rathole /app
RUN python3 -m venv /venv
RUN /venv/bin/pip3 install --upgrade pip setuptools wheel
# Install PyTorch before installing VLLM to ensure we use the right version for our CUDA install.
RUN wget -q -O - https://raw.githubusercontent.com/vllm-project/vllm/main/requirements.txt | grep -E 'torch*' > /tmp/torch_version
RUN /venv/bin/pip3 install "$(cat /tmp/torch_version)" --index-url https://download.pytorch.org/whl/cu118
# WORKDIR /local-llm-server
# Don't build VLLM because we don't do that on the inference server. Just install from pip.
# RUN /venv/bin/pip install git+https://github.com/vllm-project/vllm
RUN /venv/bin/pip install vllm
FROM nvidia/cuda:11.8.0-base-ubuntu22.04 as base
COPY --from=build /local-llm-server /local-llm-server
COPY --from=build /venv /venv
COPY --from=build /app /app
COPY --from=build /jupyterlab /jupyterlab

View File

@ -0,0 +1,47 @@
**A Docker container for running VLLM on Paperspace Gradient notebooks.**
### Running
1. In Paperspace, create a new notebook.
2. Click `Start from Scratch`.
3. Select your GPU and set the auto-shutdown timeout to 6 hours.
4. Click the `View Advanced Options` button at the bottom of the page. Enter these details in the form that appears:
- Container Name: `cyberes/vllm-paperspace:latest`
- Container Command: `/app/start.sh`
5. Start the notebook. It may take up to five minutes for them to pull and start the custom image.
6. Once the container is started, open the log viewer by clicking the icon in the bottom left of the screen. You should see errors from rathole and VLLM as a result of the blank config files. The container will create a new directory in your mounted storage: `/storage/vllm/`.
7. Enter your rathole client config in `/storage/vllm/rathole-client.toml`. If you need a visual text editor, first link the directory back to the Jupyter home: `ln -s /storage/vllm /notebooks`
8. Restart rathole with `supervisorctl restart rathole` and then view the log: `tail -f /var/log/app/rathole.log`. If you see lines that start with `INFO` and end with `Control channel established`, rathole has connected and is working. Error messages will begin with `ERROR`.
9. Download an AWQ quantization from [TheBloke](https://huggingface.co/TheBloke) to `/storage/vllm/models/`.
10. Enter your VLLM commandline args in `/storage/vllm/cmd.txt`. You need to set `--model` to the path of the model you want to load.
11. Restart VLLM with `supervisorctl restart vllm` and then view the log: `tail -f /var/log/app/vllm.log`. It may take up to three minutes to load. When you see the line:
```
INFO: Uvicorn running on http://0.0.0.0:7000 (Press CTRL+C to quit)
```
&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;VLLM is running and ready for queries.
12. In `/notebooks` (the home directory of Jupyter), the notebook `idle.ipynb` will automatically be created. Run this notebook so Paperspace does not shut down your machine due to "inactivity". You **must** keep the running notebook open in a browser tab.
### Building
You **must** have a GPU attached to your system when building the container (required for building VLLM).
1. Install the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html) and CUDA 11.8.
2. `bash build-docker.sh`
To run the container on your local machine:
```bash
sudo docker run -it --shm-size 14g --gpus all -v /home/user/testing123/notebooks:/notebooks -v /home/user/testing123/storage:/storage -p 8888:8888 cyberes/vllm-paperspace:latest
```
You will need to create a directory to mount inside the container (for example: `/home/user/testing123/`). Within this should be the folder `models` that holds the model to load, `rathole-client.toml`, and `cmd.txt`.
If you need to debug something, you can start a shell inside the container:
```bash
sudo docker run -it --shm-size 14g --gpus all -v /home/user/testing123/notebooks:/notebooks -v /home/user/testing123/storage:/storage -p 8888:8888 --entrypoint bash cyberes/vllm-paperspace:latest
```

View File

@ -0,0 +1,7 @@
#!/bin/bash
# Build and push the container.
git pull || exit
sudo docker build . -f Dockerfile.base -t cyberes/vllm-paperspace-base --no-cache && sudo docker push cyberes/vllm-paperspace-base:latest || exit
sudo docker build . -t cyberes/vllm-paperspace && sudo docker push cyberes/vllm-paperspace:latest

View File

@ -0,0 +1,40 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "49ae6555-572b-4463-ba01-cc4331932a6c",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"i = 0\n",
"while True:\n",
" print(i)\n",
" i += 1\n",
" time.sleep(1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -0,0 +1,22 @@
#!/bin/bash
# Create the required directories and files.
echo "SETTING UP FILE SYSTEM..."
mkdir -p /storage/vllm/
chown -R apiserver:apiserver /storage/vllm
touch /storage/vllm/cmd.txt
touch /storage/vllm/rathole-client.toml
# The user can store SSH auth and authorized_keys to streamline SSH login.
if [ -d /storage/vllm/ssh ]; then
cp -r /storage/vllm/ssh /root/.ssh
echo "Copied ssh from /storage"
fi
# If the user has not filled in the VLLM commandline arg file, write the default.
# cmd.txt was touched above, so test for an empty file rather than a missing one.
if [ ! -s /storage/vllm/cmd.txt ]; then
echo "--max-num-batched-tokens 4098 --quantization awq --model /storage/vllm/models/model-path" >/storage/vllm/cmd.txt
fi
# Copy the idling notebook to the Jupyter home. This overwrites idle.ipynb every time the container is started.
cp /local-llm-server/other/vllm/Docker/idle.ipynb /notebooks/idle.ipynb

View File

@ -1,13 +1,4 @@
#!/bin/bash
mkdir -p /storage/vllm/
chown -R apiserver:apiserver /storage/vllm
touch /storage/vllm/cmd.txt
touch /storage/vllm/rathole-client.toml
if [ -f /storage/vllm/ssh ]; then
cp -r /storage/vllm/ssh /root/.ssh
echo "Copied ssh from /storage"
fi
/usr/bin/supervisord
# Start the services and launch the container.
/usr/bin/supervisord -c /etc/supervisor/supervisord.conf

View File

@ -6,9 +6,4 @@ for pid in $vllm_pid; do
kill -9 $pid
done
cd /local-llm-server
git fetch
git reset --hard origin/master
/venv/bin/pip install -r requirements.txt
/venv/bin/python /local-llm-server/other/vllm/vllm_api_server.py --host 0.0.0.0 --port 7000 --max-log-len 100 $(cat /storage/vllm/cmd.txt)

View File

@ -1,5 +1,25 @@
[supervisord]
nodaemon=true
nodaemon = true
user=root
pidfile = /var/run/supervisord.pid
logfile = /var/log/app/supervisord.log
directory = /tmp
[unix_http_server]
file=/var/run/supervisor.sock
chmod=0770
[rpcinterface:supervisor]
supervisor.rpcinterface_factory = supervisor.rpcinterface:make_main_rpcinterface
[supervisorctl]
serverurl=unix:///var/run/supervisor.sock
[program:startup]
command=/app/init.sh
autostart=true
autorestart=false
startsecs=0
[program:vllm]
command=/bin/bash -c 'bash /app/start-vllm.sh 2>&1 | tee -a /var/log/app/vllm.log'
@ -24,9 +44,20 @@ user=apiserver
environment=HOME="/home/apiserver",USER="apiserver"
[program:jupyter]
command=/jupyterlab/bin/jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True
command=/jupyterlab/bin/jupyter lab --allow-root --ip=0.0.0.0 --no-browser --ServerApp.trust_xheaders=True --ServerApp.disable_check_xsrf=False --ServerApp.allow_remote_access=True --ServerApp.allow_origin='*' --ServerApp.allow_credentials=True --notebook-dir /notebooks
environment=SHELL="/bin/bash"
; JUPYTER_CONFIG_DIR="/app/jupyter"
autostart=true
autorestart=true
stdout_logfile=/dev/fd/1
stdout_logfile_maxbytes=0
stderr_logfile=/dev/fd/2
stderr_logfile_maxbytes=0
[program:ssh]
command=/usr/sbin/sshd -D
autostart=true
autorestart=true
stdout_logfile=/dev/fd/1
stdout_logfile_maxbytes=0
stderr_logfile=/dev/fd/2
stderr_logfile_maxbytes=0

View File

@ -0,0 +1,11 @@
#!/bin/bash
# Run this script to update the container.
# Will restart VLLM as well.
cd /local-llm-server || exit
git fetch
git reset --hard origin/master
supervisorctl restart vllm

View File

@ -1,21 +1,18 @@
flask~=2.3.3
flask_cors
pyyaml~=6.0.1
flask_caching
Flask-Caching==2.0.2
requests~=2.31.0
tiktoken~=0.5.0
gunicorn
gevent~=23.9.0.post1
async-timeout
flask-sock
uvicorn~=0.23.2
fastapi~=0.103.1
torch~=2.0.1
PyMySQL~=1.1.0
DBUtils~=3.0.3
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
celery[redis]
flask-sock==0.6.0
gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm==0.2.1.post1
gradio~=3.46.1
coloredlogs~=15.0.1

138
server.py
View File

@ -1,5 +1,3 @@
from llm_server.config.config import mode_ui_names
try:
import gevent.monkey
@ -7,37 +5,46 @@ try:
except ImportError:
pass
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config
import os
import sys
from pathlib import Path
import simplejson as json
from flask import Flask, jsonify, render_template, request
from flask import Flask, jsonify, render_template, request, Response
import llm_server
import config
from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config
from llm_server.custom_redis import flask_cache, redis
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.llm import get_token_count
from llm_server.routes.openai import openai_bp
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.pre_fork import server_startup
from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.sock import init_wssocket
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
# TODO: implement background thread to test backends via sending test prompts
# TODO: if backend fails request, mark it as down
# TODO: allow setting concurrent gens per-backend
# TODO: set the max tokens to that of the lowest backend
# TODO: implement RRD backend loadbalancer option
# TODO: have VLLM reject a request if it already has n == concurrent_gens running
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: use coloredlogs
# TODO: need to update opts. for workers
# TODO: separate queue item timeout for websockets (make longer, like 5 minutes)
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
# TODO: redis SCAN vs KEYS??
# TODO: is frequency penalty the same as ooba repetition penalty???
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
# Lower priority
# TODO: if a backend is at its limit of concurrent requests, choose a different one
# TODO: make error messages consistent
# TODO: support logit_bias on OpenAI and Ooba endpoints.
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# TODO: validate openai_silent_trim works as expected and only when enabled
# TODO: rewrite config storage. Store in redis so we can reload it.
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estimated wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estimated wait time lags behind the stats
# TODO: simulate OpenAI error messages regardless of endpoint
@ -59,19 +66,16 @@ except ModuleNotFoundError as e:
print('Please see README.md for install instructions.')
sys.exit(1)
import config
from llm_server import opts
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
# Fixes ConcurrentObjectUseError
# https://github.com/miguelgrinberg/simple-websocket/issues/24
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
init_wssocket(app)
flask_cache.init_app(app)
flask_cache.clear()
@ -82,18 +86,13 @@ if config_path_environ:
else:
config_path = Path(script_path, 'config', 'config.yml')
success, config, msg = load_config(config_path, script_path)
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
llm_server.llm.redis = RedisWrapper('local_llm')
create_db()
# print(app.url_map)
@app.route('/')
@ -101,20 +100,30 @@ create_db()
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
base_client_api = redis.get('base_client_api', dtype=str)
stats = generate_stats()
model_choices, default_model = get_model_choices()
if not stats['online']:
running_model = estimated_wait_sec = 'offline'
else:
running_model = redis.get('running_model', str, 'ERROR')
if default_model:
if not model_choices.get(default_model):
return 'The server is still starting up. Please wait...'
active_gen_workers = get_active_gen_workers()
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
default_model_info = model_choices[default_model]
if default_model_info['queued'] == 0 and default_model_info['processing'] >= default_model_info['concurrent_gens']:
# There will be a wait if the queue is empty but prompts are processing, but we don't
# know how long.
estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
default_estimated_wait_sec = f"less than {int(default_model_info['estimated_wait'])} seconds"
else:
estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"
default_estimated_wait_sec = f"{int(default_model_info['estimated_wait'])} seconds"
else:
default_model_info = {
'model': 'OFFLINE',
'processing': '-',
'queued': '-',
'context_size': '-',
}
default_estimated_wait_sec = 'OFFLINE'
if len(config['analytics_tracking_code']):
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
@ -127,32 +136,47 @@ def home():
info_html = ''
mode_info = ''
if opts.mode == 'vllm':
mode_info = vllm_info
base_client_api = redis.get('base_client_api', str)
for k, v in cluster_config.all().items():
if v['mode'] == 'vllm':
mode_info = vllm_info
break
return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name,
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
default_model=default_model_info['model'],
default_active_gen_workers=default_model_info['processing'],
default_proompters_in_queue=default_model_info['queued'],
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
estimated_wait=estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[opts.frontend_api_mode][0],
api_input_textbox=mode_ui_names[opts.frontend_api_mode][1],
streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2],
default_context_size=default_model_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info,
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
expose_openai_system_prompt=opts.expose_openai_system_prompt,
enable_streaming=opts.enable_streaming,
model_choices=model_choices,
proompters_5_min=stats['stats']['proompters']['5_min'],
proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
)
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
@app.route('/robots.txt')
def robots():
# TODO: have config value to deny all
# TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt
t = """User-agent: *
Allow: /"""
r = Response(t)
r.headers['Content-Type'] = 'text/plain'
return r
@app.route('/<first>')
@app.route('/<first>/<path:rest>')

View File

@ -65,6 +65,19 @@
.hidden {
display: none;
}
.header-workers {
font-weight: normal;
font-size: 14pt;
}
h3 {
font-size: 16pt;
}
.no-marker {
list-style: none;
}
</style>
</head>
@ -76,8 +89,12 @@
<h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1>
<div class="info-box">
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
<p><strong>Current Model:</strong> <span id="model">{{ default_model }}</span></p>
<p>
<strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ default_estimated_wait }}</span><br>
Processing: {{ default_active_gen_workers }}<br>
Queued: {{ default_proompters_in_queue }}
</p>
<br>
<p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
@ -91,17 +108,20 @@
<br>
<div class="info-box">
<div id="oobabooga">
<strong>Instructions:</strong>
<h3>Instructions</h3>
<div id="instructions">
<ol>
<li>In Settings > Power User Options, enable <kbd>Relaxed API URLS</kbd>.</li>
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
{% if enable_streaming %}<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>{% endif %}
{% if enable_streaming %}
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
{% endif %}
<li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
API key</kbd> textbox.
</li>
<li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ default_context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
</li>
</ol>
@ -120,13 +140,45 @@
<br>
<div class="info-box">
<pre><code class="language-json" style="background-color: white">{{ stats_json|safe }}</code></pre>
<h3>Statistics</h3>
Proompters:
<ul style="margin-top: 5px;">
<li class="no-marker">5 minutes: {{ proompters_5_min }}</li>
<li class="no-marker">24 hours: {{ proompters_24_hrs }}</li>
</ul>
</div>
<br>
{% for key, value in model_choices.items() %}
<div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
{% if value.queued == 0 and value.processing >= value.concurrent_gens %}
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
{% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
{% else %}
{% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
{% endif %}
<p>
<strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
Processing: {{ value.processing }}<br>
Queued: {{ value.queued }}<br>
</p>
<p>
<strong>Client API URL:</strong> {{ value.client_api }}<br>
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
</p>
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
</div>
<br>
{% endfor %}
</div>
<div class="footer">
<a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a>
</div>
<script>hljs.highlightAll();</script>
</body>
</html>