Merge cluster to master (#3)
Co-authored-by: Cyberes <cyberes@evulid.cc> Reviewed-on: #3
This commit is contained in:
parent
561820fb9e
commit
0059e7956c
|
@ -43,7 +43,9 @@ To set up token auth, add rows to the `token_auth` table in the SQLite database.
|
|||
|
||||
### Use
|
||||
|
||||
If you see unexpected errors in the console, make sure `daemon.py` is running or else the required data will be missing from Redis. You may need to wait a few minutes for the daemon to populate the database.
|
||||
|
||||
Flask may give unusual errors when running `python server.py`. I think this is coming from Flask-Socket. Running with Gunicorn seems to fix the issue: `gunicorn -b :5000 --worker-class gevent server:app`
|
||||
|
||||
### To Do
|
||||
|
||||
|
|
68
daemon.py
68
daemon.py
|
@ -1,22 +1,19 @@
|
|||
import time
|
||||
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
||||
gevent.monkey.patch_all()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from llm_server.config.load import load_config
|
||||
from llm_server.database.create import create_db
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.workers.app import start_background
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.config.load import load_config, parse_backends
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.create import create_db
|
||||
from llm_server.logging import create_logger, logging_info, init_logging
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.threader import start_background
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
config_path_environ = os.getenv("CONFIG_PATH")
|
||||
|
@ -26,19 +23,46 @@ else:
|
|||
config_path = Path(script_path, 'config', 'config.yml')
|
||||
|
||||
if __name__ == "__main__":
|
||||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
parser = argparse.ArgumentParser(description='Daemon microservice.')
|
||||
parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
|
||||
parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
|
||||
args = parser.parse_args()
|
||||
|
||||
success, config, msg = load_config(config_path, script_path)
|
||||
# TODO: have this be set by either the arg or a config value
|
||||
if args.debug:
|
||||
logging_info.level = logging.DEBUG
|
||||
|
||||
init_logging()
|
||||
logger = create_logger('daemon')
|
||||
logger.debug('Debug logging enabled.')
|
||||
|
||||
if not args.no_reset:
|
||||
Redis().flushall()
|
||||
logger.info('Flushed Redis.')
|
||||
|
||||
success, config, msg = load_config(config_path)
|
||||
if not success:
|
||||
print('Failed to load config:', msg)
|
||||
logger.info(f'Failed to load config: {msg}')
|
||||
sys.exit(1)
|
||||
|
||||
create_db()
|
||||
|
||||
cluster_config.clear()
|
||||
cluster_config.load(parse_backends(config))
|
||||
|
||||
logger.info('Loading backend stats...')
|
||||
generate_stats(regen=True)
|
||||
|
||||
start_background()
|
||||
|
||||
redis.set('daemon_started', 1)
|
||||
print('== Daemon Setup Complete ==\n')
|
||||
# Give some time for the background threads to get themselves ready to go.
|
||||
time.sleep(2)
|
||||
|
||||
while True:
|
||||
time.sleep(3600)
|
||||
redis.set('daemon_started', 1)
|
||||
logger.info('== Daemon Setup Complete ==')
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(3600)
|
||||
except KeyboardInterrupt:
|
||||
redis.set('daemon_started', 0)
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
import numpy as np
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.llm.info import get_info
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers_model
|
||||
|
||||
|
||||
def get_backends_from_model(model_name: str):
|
||||
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
|
||||
|
||||
|
||||
def get_running_models():
|
||||
return redis_running_models.keys()
|
||||
|
||||
|
||||
def purge_backend_from_running_models(backend_url: str):
|
||||
keys = redis_running_models.keys()
|
||||
pipeline = redis_running_models.pipeline()
|
||||
for model in keys:
|
||||
pipeline.srem(model, backend_url)
|
||||
pipeline.execute()
|
||||
|
||||
|
||||
def is_valid_model(model_name: str):
|
||||
return redis_running_models.exists(model_name)
|
||||
|
||||
|
||||
def test_backend(backend_url: str, test_prompt: bool = False):
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if test_prompt:
|
||||
handler = VLLMBackend(backend_url)
|
||||
parameters, _ = handler.get_parameters({
|
||||
"stream": False,
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 3,
|
||||
})
|
||||
data = {
|
||||
'prompt': 'test prompt',
|
||||
**parameters
|
||||
}
|
||||
try:
|
||||
success, response, err = generator(data, backend_url, timeout=10)
|
||||
if not success or not response or err:
|
||||
return False, {}
|
||||
except:
|
||||
return False, {}
|
||||
i = get_info(backend_url, backend_info['mode'])
|
||||
if not i.get('model'):
|
||||
return False, {}
|
||||
return True, i
|
||||
|
||||
|
||||
def get_model_choices(regen: bool = False):
|
||||
if not regen:
|
||||
c = redis.getp('model_choices')
|
||||
if c:
|
||||
return c
|
||||
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
running_models = get_running_models()
|
||||
model_choices = {}
|
||||
for model in running_models:
|
||||
b = get_backends_from_model(model)
|
||||
|
||||
context_size = []
|
||||
avg_gen_per_worker = []
|
||||
concurrent_gens = 0
|
||||
for backend_url in b:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if backend_info.get('model_config'):
|
||||
context_size.append(backend_info['model_config']['max_position_embeddings'])
|
||||
if backend_info.get('average_generation_elapsed_sec'):
|
||||
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
|
||||
concurrent_gens += backend_info['concurrent_gens']
|
||||
|
||||
active_gen_workers = get_active_gen_workers_model(model)
|
||||
proompters_in_queue = priority_queue.len(model)
|
||||
|
||||
if len(avg_gen_per_worker):
|
||||
average_generation_elapsed_sec = np.average(avg_gen_per_worker)
|
||||
else:
|
||||
average_generation_elapsed_sec = 0
|
||||
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, concurrent_gens, active_gen_workers)
|
||||
|
||||
model_choices[model] = {
|
||||
'model': model,
|
||||
'client_api': f'https://{base_client_api}/{model}',
|
||||
'ws_client_api': f'wss://{base_client_api}/{model}/v1/stream' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/{model}/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'backend_count': len(b),
|
||||
'estimated_wait': estimated_wait_sec,
|
||||
'queued': proompters_in_queue,
|
||||
'processing': active_gen_workers,
|
||||
'avg_generation_time': average_generation_elapsed_sec,
|
||||
'concurrent_gens': concurrent_gens
|
||||
}
|
||||
|
||||
if len(context_size):
|
||||
model_choices[model]['context_size'] = min(context_size)
|
||||
|
||||
# Python wants to sort lowercase vs. uppercase letters differently.
|
||||
model_choices = dict(sorted(model_choices.items(), key=lambda item: item[0].upper()))
|
||||
|
||||
default_backend_url = get_a_cluster_backend()
|
||||
default_backend_info = cluster_config.get_backend(default_backend_url)
|
||||
if not default_backend_info.get('model'):
|
||||
return {}, None
|
||||
default_model = default_backend_info['model']
|
||||
|
||||
redis.setp('model_choices', (model_choices, default_model))
|
||||
return model_choices, default_model
|
|
@ -0,0 +1,124 @@
|
|||
import hashlib
|
||||
import pickle
|
||||
import traceback
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.custom_redis import RedisCustom
|
||||
from llm_server.routes.helpers.model import estimate_model_size
|
||||
|
||||
|
||||
class RedisClusterStore:
|
||||
def __init__(self, name: str, **kwargs):
|
||||
self.name = name
|
||||
self.config_redis = RedisCustom(name, **kwargs)
|
||||
|
||||
def clear(self):
|
||||
self.config_redis.flush()
|
||||
|
||||
def load(self, config: dict):
|
||||
for k, v in config.items():
|
||||
self.add_backend(k, v)
|
||||
|
||||
def add_backend(self, name: str, values: dict):
|
||||
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
|
||||
self.set_backend_value(name, 'online', False)
|
||||
h = hashlib.sha256(name.encode('utf-8')).hexdigest()
|
||||
self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
|
||||
|
||||
def set_backend_value(self, backend: str, key: str, value):
|
||||
# By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
|
||||
self.config_redis.hset(backend, key, pickle.dumps(value))
|
||||
|
||||
def get_backend(self, name: str):
|
||||
r = self.config_redis.hgetall(name)
|
||||
output = {}
|
||||
for k, v in r.items():
|
||||
output[k.decode('utf8')] = pickle.loads(v)
|
||||
return output
|
||||
|
||||
def all(self):
|
||||
keys = self.config_redis.keys('*')
|
||||
if keys:
|
||||
result = {}
|
||||
for key in keys:
|
||||
if key != f'{self.name}:____':
|
||||
v = self.get_backend(key)
|
||||
result[key] = v
|
||||
return result
|
||||
else:
|
||||
return {}
|
||||
|
||||
def validate_backend(self, backend_url: str):
|
||||
"""
|
||||
Returns the backend URL that was given, or a new one if that was offline.
|
||||
:param backend_url:
|
||||
:return:
|
||||
"""
|
||||
backend_info = self.get_backend(backend_url)
|
||||
if not backend_info['online']:
|
||||
old = backend_url
|
||||
backend_url = get_a_cluster_backend()
|
||||
print(f'Backend {old} offline. Request was redirected to {backend_url}')
|
||||
return backend_url
|
||||
|
||||
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
|
||||
|
||||
def get_backends():
|
||||
backends = cluster_config.all()
|
||||
result = {}
|
||||
for k, v in backends.items():
|
||||
b = cluster_config.get_backend(k)
|
||||
status = b.get('online', False)
|
||||
priority = b['priority']
|
||||
result[k] = {'status': status, 'priority': priority}
|
||||
|
||||
try:
|
||||
if not opts.prioritize_by_size:
|
||||
online_backends = sorted(
|
||||
((url, info) for url, info in backends.items() if info['online']),
|
||||
key=lambda kv: -kv[1]['priority'],
|
||||
reverse=True
|
||||
)
|
||||
else:
|
||||
online_backends = sorted(
|
||||
((url, info) for url, info in backends.items() if info['online']),
|
||||
key=lambda kv: estimate_model_size(kv[1]['model_config']),
|
||||
reverse=True
|
||||
)
|
||||
offline_backends = sorted(
|
||||
((url, info) for url, info in backends.items() if not info['online']),
|
||||
key=lambda kv: -kv[1]['priority'],
|
||||
reverse=True
|
||||
)
|
||||
return [url for url, info in online_backends], [url for url, info in offline_backends]
|
||||
except KeyError:
|
||||
traceback.print_exc()
|
||||
print(backends)
|
||||
|
||||
|
||||
def get_a_cluster_backend(model=None):
|
||||
"""
|
||||
Get a backend from Redis. If there are no online backends, return None.
|
||||
If `model` is not supplied, we will pick one ourself.
|
||||
"""
|
||||
if model:
|
||||
# First, determine if there are multiple backends hosting the same model.
|
||||
backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]
|
||||
|
||||
# If so, create an iterator for those backends
|
||||
if len(backends_hosting_model):
|
||||
add_backend_cycler(model, backends_hosting_model)
|
||||
cycled = redis_cycle(model)
|
||||
if len(cycled):
|
||||
return cycled[0]
|
||||
else:
|
||||
# No backend hosting that model
|
||||
return None
|
||||
else:
|
||||
online, _ = get_backends()
|
||||
if len(online):
|
||||
return online[0]
|
|
@ -0,0 +1,39 @@
|
|||
import redis
|
||||
|
||||
redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)
|
||||
|
||||
|
||||
def redis_cycle(list_name):
|
||||
"""
|
||||
Emulates itertools.cycle() but returns the complete shuffled list.
|
||||
:param list_name:
|
||||
:return:
|
||||
"""
|
||||
pipeline = redis_cycler_db.pipeline()
|
||||
pipeline.lpop(list_name)
|
||||
to_move = pipeline.execute()[0]
|
||||
if not to_move:
|
||||
return []
|
||||
pipeline.rpush(list_name, to_move)
|
||||
pipeline.lrange(list_name, 0, -1)
|
||||
results = pipeline.execute()
|
||||
new_list = results[-1]
|
||||
return [x.decode('utf-8') for x in new_list]
|
||||
|
||||
|
||||
def add_backend_cycler(list_name: str, new_elements: list):
|
||||
existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
|
||||
existing_set = set(existing_elements)
|
||||
|
||||
with redis_cycler_db.pipeline() as pipe:
|
||||
# Add elements
|
||||
for element in new_elements:
|
||||
if element not in existing_set:
|
||||
pipe.rpush(list_name, element)
|
||||
|
||||
# Remove elements
|
||||
for element in existing_set:
|
||||
if element not in new_elements:
|
||||
pipe.lrem(list_name, 0, element)
|
||||
|
||||
pipe.execute()
|
|
@ -0,0 +1,3 @@
|
|||
from llm_server.custom_redis import RedisCustom
|
||||
|
||||
redis_running_models = RedisCustom('running_models')
|
|
@ -0,0 +1,38 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
from llm_server.cluster.backend import test_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
|
||||
|
||||
def cluster_worker():
|
||||
counter = 0
|
||||
while True:
|
||||
test_prompt = False
|
||||
if counter % 4 == 0:
|
||||
# Only send a test prompt every 120 seconds.
|
||||
test_prompt = True
|
||||
threads = []
|
||||
for n, v in cluster_config.all().items():
|
||||
thread = Thread(target=check_backend, args=(n, v, test_prompt))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
time.sleep(15)
|
||||
counter += 1
|
||||
|
||||
|
||||
def check_backend(n, v, test_prompt):
|
||||
online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
|
||||
if online:
|
||||
running_model = backend_info['model']
|
||||
for k, v in backend_info.items():
|
||||
cluster_config.set_backend_value(n, k, v)
|
||||
redis_running_models.sadd(running_model, n)
|
||||
else:
|
||||
for model in redis_running_models.keys():
|
||||
redis_running_models.srem(model, n)
|
||||
|
||||
cluster_config.set_backend_value(n, 'online', online)
|
|
@ -28,16 +28,19 @@ config_default_vars = {
|
|||
'openai_force_no_hashes': True,
|
||||
'include_system_tokens_in_stats': True,
|
||||
'openai_moderation_scan_last_n': 5,
|
||||
'openai_moderation_workers': 10,
|
||||
'openai_org_name': 'OpenAI',
|
||||
'openai_silent_trim': False,
|
||||
'openai_moderation_enabled': True,
|
||||
'netdata_root': None
|
||||
'netdata_root': None,
|
||||
'show_backends': True,
|
||||
'background_homepage_cacher': True,
|
||||
'openai_moderation_timeout': 5,
|
||||
'prioritize_by_size': False
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
config_required_vars = ['cluster', 'frontend_api_mode', 'llm_middleware_name']
|
||||
|
||||
mode_ui_names = {
|
||||
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||
'ooba': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||
'vllm': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||
}
|
||||
|
||||
|
|
|
@ -3,38 +3,28 @@ import sys
|
|||
|
||||
import openai
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.config.config import ConfigLoader, config_default_vars, config_required_vars
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.database import get_number_of_rows
|
||||
from llm_server.helpers import resolve_path
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.queue import PriorityQueue
|
||||
|
||||
|
||||
def load_config(config_path, script_path):
|
||||
def load_config(config_path):
|
||||
config_loader = ConfigLoader(config_path, config_default_vars, config_required_vars)
|
||||
success, config, msg = config_loader.load_config()
|
||||
if not success:
|
||||
return success, config, msg
|
||||
|
||||
# Resolve relative directory to the directory of the script
|
||||
if config['database_path'].startswith('./'):
|
||||
config['database_path'] = resolve_path(script_path, config['database_path'].strip('./'))
|
||||
|
||||
if config['mode'] not in ['oobabooga', 'vllm']:
|
||||
print('Unknown mode:', config['mode'])
|
||||
sys.exit(1)
|
||||
|
||||
# TODO: this is atrocious
|
||||
opts.mode = config['mode']
|
||||
opts.auth_required = config['auth_required']
|
||||
opts.log_prompts = config['log_prompts']
|
||||
opts.concurrent_gens = config['concurrent_gens']
|
||||
opts.frontend_api_client = config['frontend_api_client']
|
||||
opts.context_size = config['token_limit']
|
||||
opts.show_num_prompts = config['show_num_prompts']
|
||||
opts.show_uptime = config['show_uptime']
|
||||
opts.backend_url = config['backend_url'].strip('/')
|
||||
opts.cluster = config['cluster']
|
||||
opts.show_total_output_tokens = config['show_total_output_tokens']
|
||||
opts.netdata_root = config['netdata_root']
|
||||
opts.simultaneous_requests_per_ip = config['simultaneous_requests_per_ip']
|
||||
|
@ -53,10 +43,20 @@ def load_config(config_path, script_path):
|
|||
opts.openai_force_no_hashes = config['openai_force_no_hashes']
|
||||
opts.include_system_tokens_in_stats = config['include_system_tokens_in_stats']
|
||||
opts.openai_moderation_scan_last_n = config['openai_moderation_scan_last_n']
|
||||
opts.openai_moderation_workers = config['openai_moderation_workers']
|
||||
opts.openai_org_name = config['openai_org_name']
|
||||
opts.openai_silent_trim = config['openai_silent_trim']
|
||||
opts.openai_moderation_enabled = config['openai_moderation_enabled']
|
||||
opts.show_backends = config['show_backends']
|
||||
opts.background_homepage_cacher = config['background_homepage_cacher']
|
||||
opts.openai_moderation_timeout = config['openai_moderation_timeout']
|
||||
opts.frontend_api_mode = config['frontend_api_mode']
|
||||
opts.prioritize_by_size = config['prioritize_by_size']
|
||||
|
||||
# Scale the number of workers.
|
||||
for item in config['cluster']:
|
||||
opts.cluster_workers += item['concurrent_gens']
|
||||
|
||||
llm_server.routes.queue.priority_queue = PriorityQueue([x['backend_url'] for x in config['cluster']])
|
||||
|
||||
if opts.openai_expose_our_model and not opts.openai_api_key:
|
||||
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
|
||||
|
@ -78,6 +78,16 @@ def load_config(config_path, script_path):
|
|||
if config['load_num_prompts']:
|
||||
redis.set('proompts', get_number_of_rows('prompts'))
|
||||
|
||||
redis.set('backend_mode', opts.mode)
|
||||
|
||||
return success, config, msg
|
||||
|
||||
|
||||
def parse_backends(config):
|
||||
if not config.get('cluster'):
|
||||
return False
|
||||
cluster = config.get('cluster')
|
||||
config = {}
|
||||
for item in cluster:
|
||||
backend_url = item['backend_url'].strip('/')
|
||||
item['backend_url'] = backend_url
|
||||
config[backend_url] = item
|
||||
return config
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from llm_server.custom_redis import RedisCustom
|
||||
|
||||
redis_config = RedisCustom('redis_config')
|
|
@ -1,24 +1,27 @@
|
|||
import pickle
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Callable, List, Mapping, Union
|
||||
from typing import Callable, List, Mapping, Optional, Union
|
||||
|
||||
import redis as redis_pkg
|
||||
import simplejson as json
|
||||
from flask_caching import Cache
|
||||
from redis import Redis
|
||||
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT
|
||||
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
|
||||
|
||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||
|
||||
ONE_MONTH_SECONDS = 2678000
|
||||
|
||||
|
||||
class RedisWrapper:
|
||||
class RedisCustom(Redis):
|
||||
"""
|
||||
A wrapper class to set prefixes to keys.
|
||||
A simple wrapper class for Redis to create a "namespace" within a DB,
|
||||
which simplyifies key management.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, **kwargs):
|
||||
super().__init__()
|
||||
self.redis = Redis(**kwargs)
|
||||
self.prefix = prefix
|
||||
try:
|
||||
|
@ -34,12 +37,11 @@ class RedisWrapper:
|
|||
def set(self, key, value, ex: Union[ExpiryT, None] = None):
|
||||
return self.redis.set(self._key(key), value, ex=ex)
|
||||
|
||||
def get(self, key, dtype=None, default=None):
|
||||
"""
|
||||
:param key:
|
||||
:param dtype: convert to this type
|
||||
:return:
|
||||
"""
|
||||
def get(self, key, default=None, dtype=None):
|
||||
# TODO: use pickle
|
||||
import inspect
|
||||
if inspect.isclass(default):
|
||||
raise Exception
|
||||
|
||||
d = self.redis.get(self._key(key))
|
||||
if dtype and d:
|
||||
|
@ -108,7 +110,10 @@ class RedisWrapper:
|
|||
):
|
||||
return self.redis.hincrby(self._key(name), key, amount)
|
||||
|
||||
def hdel(self, name: str, *keys: List):
|
||||
def zcard(self, name: KeyT):
|
||||
return self.redis.zcard(self._key(name))
|
||||
|
||||
def hdel(self, name: str, *keys: str):
|
||||
return self.redis.hdel(self._key(name), *keys)
|
||||
|
||||
def hget(
|
||||
|
@ -129,9 +134,62 @@ class RedisWrapper:
|
|||
):
|
||||
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
|
||||
|
||||
def lpush(self, name: str, *values: FieldT):
|
||||
return self.redis.lpush(self._key(name), *values)
|
||||
|
||||
def hset(
|
||||
self,
|
||||
name: str,
|
||||
key: Optional = None,
|
||||
value=None,
|
||||
mapping: Optional[dict] = None,
|
||||
items: Optional[list] = None,
|
||||
):
|
||||
return self.redis.hset(self._key(name), key, value, mapping, items)
|
||||
|
||||
def hkeys(self, name: str):
|
||||
return self.redis.hkeys(self._key(name))
|
||||
|
||||
def hmget(self, name: str, keys: List, *args: List):
|
||||
return self.redis.hmget(self._key(name), keys, *args)
|
||||
|
||||
def hgetall(self, name: str):
|
||||
return self.redis.hgetall(self._key(name))
|
||||
|
||||
def keys(self, pattern: PatternT = "*", **kwargs):
|
||||
raw_keys = self.redis.keys(self._key(pattern), **kwargs)
|
||||
keys = []
|
||||
for key in raw_keys:
|
||||
p = key.decode('utf-8').split(':')
|
||||
if len(p) >= 2:
|
||||
# Delete prefix
|
||||
del p[0]
|
||||
k = ':'.join(p)
|
||||
if k != '____':
|
||||
keys.append(k)
|
||||
return keys
|
||||
|
||||
def pipeline(self, transaction=True, shard_hint=None):
|
||||
return self.redis.pipeline(transaction, shard_hint)
|
||||
|
||||
def smembers(self, name: str):
|
||||
return self.redis.smembers(self._key(name))
|
||||
|
||||
def spop(self, name: str, count: Optional[int] = None):
|
||||
return self.redis.spop(self._key(name), count)
|
||||
|
||||
def rpoplpush(self, src, dst):
|
||||
return self.redis.rpoplpush(src, dst)
|
||||
|
||||
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
|
||||
return self.redis.zpopmin(self._key(name), count)
|
||||
|
||||
def exists(self, *names: KeyT):
|
||||
n = []
|
||||
for name in names:
|
||||
n.append(self._key(name))
|
||||
return self.redis.exists(*n)
|
||||
|
||||
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
|
||||
return self.set(key, json.dumps(dict_value), ex=ex)
|
||||
|
||||
|
@ -142,6 +200,15 @@ class RedisWrapper:
|
|||
else:
|
||||
return json.loads(r.decode("utf-8"))
|
||||
|
||||
def setp(self, name, value):
|
||||
self.redis.set(self._key(name), pickle.dumps(value))
|
||||
|
||||
def getp(self, name: str):
|
||||
r = self.redis.get(self._key(name))
|
||||
if r:
|
||||
return pickle.loads(r)
|
||||
return r
|
||||
|
||||
def flush(self):
|
||||
flushed = []
|
||||
for key in self.redis.scan_iter(f'{self.prefix}:*'):
|
||||
|
@ -149,5 +216,40 @@ class RedisWrapper:
|
|||
self.redis.delete(key)
|
||||
return flushed
|
||||
|
||||
def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
|
||||
self.flush()
|
||||
return True
|
||||
|
||||
redis = RedisWrapper('local_llm')
|
||||
def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
|
||||
self.flush()
|
||||
return True
|
||||
|
||||
def lrange(self, name: str, start: int, end: int):
|
||||
return self.redis.lrange(self._key(name), start, end)
|
||||
|
||||
def delete(self, *names: KeyT):
|
||||
return self.redis.delete(*[self._key(i) for i in names])
|
||||
|
||||
def lpop(self, name: str, count: Optional[int] = None):
|
||||
return self.redis.lpop(self._key(name), count)
|
||||
|
||||
def zrange(
|
||||
self,
|
||||
name: KeyT,
|
||||
start: int,
|
||||
end: int,
|
||||
desc: bool = False,
|
||||
withscores: bool = False,
|
||||
score_cast_func: Union[type, Callable] = float,
|
||||
byscore: bool = False,
|
||||
bylex: bool = False,
|
||||
offset: int = None,
|
||||
num: int = None,
|
||||
):
|
||||
return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
|
||||
|
||||
def zrem(self, name: KeyT, *values: FieldT):
|
||||
return self.redis.zrem(self._key(name), *values)
|
||||
|
||||
|
||||
redis = RedisCustom('local_llm')
|
|
@ -5,20 +5,20 @@ class DatabaseConnection:
|
|||
host: str = None
|
||||
username: str = None
|
||||
password: str = None
|
||||
database: str = None
|
||||
database_name: str = None
|
||||
|
||||
def init_db(self, host, username, password, database):
|
||||
def init_db(self, host, username, password, database_name):
|
||||
self.host = host
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.database = database
|
||||
self.database_name = database_name
|
||||
|
||||
def cursor(self):
|
||||
db = pymysql.connect(
|
||||
host=self.host,
|
||||
user=self.username,
|
||||
password=self.password,
|
||||
database=self.database,
|
||||
database=self.database_name,
|
||||
charset='utf8mb4',
|
||||
autocommit=True,
|
||||
)
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.llm.vllm import tokenize
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.llm import get_token_count
|
||||
|
||||
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
|
||||
def do_db_log(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
|
||||
assert isinstance(prompt, str)
|
||||
assert isinstance(backend_url, str)
|
||||
|
||||
# Try not to shove JSON into the database.
|
||||
if isinstance(response, dict) and response.get('results'):
|
||||
response = response['results'][0]['text']
|
||||
try:
|
||||
|
@ -19,10 +23,11 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
except:
|
||||
pass
|
||||
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt)
|
||||
prompt_tokens = get_token_count(prompt, backend_url)
|
||||
|
||||
if not is_error:
|
||||
if not response_tokens:
|
||||
response_tokens = llm_server.llm.get_token_count(response)
|
||||
response_tokens = get_token_count(response, backend_url)
|
||||
else:
|
||||
response_tokens = None
|
||||
|
||||
|
@ -43,7 +48,9 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
if token:
|
||||
increment_token_uses(token)
|
||||
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
running_model = backend_info.get('model')
|
||||
backend_mode = backend_info['mode']
|
||||
timestamp = int(time.time())
|
||||
cursor = database.cursor()
|
||||
try:
|
||||
|
@ -52,7 +59,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
(ip, token, running_model, backend_mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
@ -179,3 +186,21 @@ def increment_token_uses(token):
|
|||
cursor.execute('UPDATE token_auth SET uses = uses + 1 WHERE token = %s', (token,))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
||||
def get_token_ratelimit(token):
|
||||
priority = 9990
|
||||
simultaneous_ip = opts.simultaneous_requests_per_ip
|
||||
if token:
|
||||
cursor = database.cursor()
|
||||
try:
|
||||
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (token,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
priority, simultaneous_ip = result
|
||||
if simultaneous_ip is None:
|
||||
# No ratelimit for this token if null
|
||||
simultaneous_ip = 999999999
|
||||
finally:
|
||||
cursor.close()
|
||||
return priority, simultaneous_ip
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
import pickle
|
||||
from typing import Union
|
||||
|
||||
from redis import Redis
|
||||
|
||||
|
||||
def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
|
||||
assert isinstance(prompt, str)
|
||||
assert isinstance(backend_url, str)
|
||||
|
||||
r = Redis(host='localhost', port=6379, db=3)
|
||||
data = {
|
||||
'function': 'log_prompt',
|
||||
'args': [],
|
||||
'kwargs': {
|
||||
'ip': ip,
|
||||
'token': token,
|
||||
'prompt': prompt,
|
||||
'response': response,
|
||||
'gen_time': gen_time,
|
||||
'parameters': parameters,
|
||||
'headers': dict(headers) if headers else headers,
|
||||
'backend_response_code': backend_response_code,
|
||||
'request_url': request_url,
|
||||
'backend_url': backend_url,
|
||||
'response_tokens': response_tokens,
|
||||
'is_error': is_error
|
||||
}
|
||||
}
|
||||
r.publish('database-logger', pickle.dumps(data))
|
|
@ -8,7 +8,7 @@ import simplejson as json
|
|||
from flask import make_response
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def resolve_path(*p: str):
|
||||
|
@ -54,13 +54,14 @@ def jsonify_pretty(json_dict: Union[list, dict], status=200, indent=4, sort_keys
|
|||
|
||||
def round_up_base(n, base):
|
||||
if base == 0:
|
||||
print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
|
||||
# TODO: I don't think passing (0, 0) to this function is a sign of any underlying issues.
|
||||
# print('round_up_base DIVIDE BY ZERO ERROR????', n, base)
|
||||
return 0
|
||||
return math.ceil(n / base) * base
|
||||
|
||||
|
||||
def auto_set_base_client_api(request):
|
||||
http_host = redis.get('http_host', str)
|
||||
http_host = redis.get('http_host', dtype=str)
|
||||
host = request.headers.get("Host")
|
||||
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
|
||||
# If the current http_host is not an IP, don't do anything.
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
import threading
|
||||
|
||||
|
||||
class ThreadSafeInteger:
|
||||
def __init__(self, value=0):
|
||||
self.value = value
|
||||
self._value_lock = threading.Lock()
|
||||
|
||||
def increment(self):
|
||||
with self._value_lock:
|
||||
self.value += 1
|
||||
return self.value
|
|
@ -1,11 +1,30 @@
|
|||
import tiktoken
|
||||
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.llm import oobabooga, vllm
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.logging import create_logger
|
||||
|
||||
|
||||
def get_token_count(prompt: str):
|
||||
backend_mode = redis.get('backend_mode', str)
|
||||
def fallback_tokenizer(prompt: str):
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
return len(tokenizer.encode(prompt)) + 10
|
||||
|
||||
|
||||
def get_token_count(prompt: str, backend_url: str):
|
||||
backend_url = cluster_config.validate_backend(backend_url)
|
||||
if not backend_url:
|
||||
logger = create_logger('tokenizer')
|
||||
logger.warning('using fallback tokenizer as there is no valid backend')
|
||||
return fallback_tokenizer(prompt)
|
||||
|
||||
backend_mode = cluster_config.get_backend(backend_url).get('mode')
|
||||
if not backend_mode:
|
||||
logger = create_logger('tokenizer')
|
||||
logger.warning("using fallback tokenizer as the backend isn't initalized")
|
||||
return fallback_tokenizer(prompt)
|
||||
|
||||
if backend_mode == 'vllm':
|
||||
return vllm.tokenize(prompt)
|
||||
return vllm.tokenize(prompt, backend_url)
|
||||
elif backend_mode == 'ooba':
|
||||
return oobabooga.tokenize(prompt)
|
||||
else:
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
def generator(request_json_body):
|
||||
if opts.mode == 'oobabooga':
|
||||
def generator(request_json_body, cluster_backend, timeout: int = None):
|
||||
mode = cluster_config.get_backend(cluster_backend)['mode']
|
||||
if mode == 'ooba':
|
||||
# from .oobabooga.generate import generate
|
||||
# return generate(request_json_body)
|
||||
raise NotImplementedError
|
||||
elif opts.mode == 'vllm':
|
||||
elif mode == 'vllm':
|
||||
from .vllm.generate import generate
|
||||
r = generate(request_json_body)
|
||||
return r
|
||||
return generate(request_json_body, cluster_backend, timeout=timeout)
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -3,23 +3,35 @@ import requests
|
|||
from llm_server import opts
|
||||
|
||||
|
||||
def get_running_model():
|
||||
# TODO: cache the results for 1 min so we don't have to keep calling the backend
|
||||
# TODO: only use one try/catch
|
||||
|
||||
if opts.mode == 'oobabooga':
|
||||
def get_running_model(backend_url: str, mode: str):
|
||||
if mode == 'ooba':
|
||||
try:
|
||||
backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
|
||||
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
|
||||
r_json = backend_response.json()
|
||||
return r_json['result'], None
|
||||
except Exception as e:
|
||||
return False, e
|
||||
elif opts.mode == 'vllm':
|
||||
elif mode == 'vllm':
|
||||
try:
|
||||
backend_response = requests.get(f'{opts.backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
|
||||
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
|
||||
r_json = backend_response.json()
|
||||
return r_json['model'], None
|
||||
except Exception as e:
|
||||
return False, e
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
||||
def get_info(backend_url: str, mode: str):
|
||||
if mode == 'ooba':
|
||||
return {}
|
||||
# raise NotImplementedError
|
||||
elif mode == 'vllm':
|
||||
try:
|
||||
r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
|
||||
j = r.json()
|
||||
except Exception as e:
|
||||
return {}
|
||||
return j
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -2,14 +2,17 @@ from typing import Tuple, Union
|
|||
|
||||
import flask
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
||||
class LLMBackend:
|
||||
_default_params: dict
|
||||
|
||||
def __init__(self, backend_url: str):
|
||||
self.backend_url = backend_url
|
||||
self.backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
||||
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -32,14 +35,16 @@ class LLMBackend:
|
|||
"""
|
||||
If a backend needs to do other checks not related to the prompt or parameters.
|
||||
Default is no extra checks preformed.
|
||||
:param request:
|
||||
:param prompt:
|
||||
:param parameters:
|
||||
:return:
|
||||
"""
|
||||
return True, None
|
||||
|
||||
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
|
||||
prompt_len = get_token_count(prompt)
|
||||
if prompt_len > opts.context_size - 10:
|
||||
model_name = redis.get('running_model', str, 'NO MODEL ERROR')
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
|
||||
prompt_len = get_token_count(prompt, self.backend_url)
|
||||
token_limit = self.backend_info['model_config']['max_position_embeddings']
|
||||
if prompt_len > token_limit - 10:
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {self.backend_info["model"]}). Please lower your context size'
|
||||
return True, None
|
||||
|
|
|
@ -1,78 +1,6 @@
|
|||
from flask import jsonify
|
||||
|
||||
from ..llm_backend import LLMBackend
|
||||
from ...database.database import log_prompt
|
||||
from ...helpers import safe_list_get
|
||||
from ...routes.cache import redis
|
||||
from ...routes.helpers.client import format_sillytavern_err
|
||||
from ...routes.helpers.http import validate_json
|
||||
|
||||
|
||||
class OobaboogaBackend(LLMBackend):
|
||||
default_params = {}
|
||||
|
||||
def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
raise NotImplementedError('need to implement default_params')
|
||||
|
||||
backend_err = False
|
||||
response_valid_json, response_json_body = validate_json(response)
|
||||
if response:
|
||||
try:
|
||||
# Be extra careful when getting attributes from the response object
|
||||
response_status_code = response.status_code
|
||||
except:
|
||||
response_status_code = 0
|
||||
else:
|
||||
response_status_code = None
|
||||
|
||||
# ===============================================
|
||||
|
||||
# We encountered an error
|
||||
if not success or not response or error_msg:
|
||||
if not error_msg or error_msg == '':
|
||||
error_msg = 'Unknown error.'
|
||||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = format_sillytavern_err(error_msg, 'error')
|
||||
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
'msg': error_msg,
|
||||
'results': [{'text': backend_response}]
|
||||
}), 400
|
||||
|
||||
# ===============================================
|
||||
|
||||
if response_valid_json:
|
||||
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
|
||||
if not backend_response:
|
||||
# Ooba doesn't return any error messages so we will just tell the client an error occurred
|
||||
backend_err = True
|
||||
backend_response = format_sillytavern_err(
|
||||
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
|
||||
'error')
|
||||
response_json_body['results'][0]['text'] = backend_response
|
||||
|
||||
if not backend_err:
|
||||
redis.incr('proompts')
|
||||
|
||||
log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
|
||||
return jsonify({
|
||||
**response_json_body
|
||||
}), 200
|
||||
else:
|
||||
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
|
||||
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
'msg': 'the backend did not return valid JSON',
|
||||
'results': [{'text': backend_response}]
|
||||
}), 400
|
||||
|
||||
def validate_params(self, params_dict: dict):
|
||||
# No validation required
|
||||
return True, None
|
||||
|
||||
def get_parameters(self, parameters):
|
||||
del parameters['prompt']
|
||||
return parameters
|
||||
def __int__(self):
|
||||
return
|
||||
|
|
|
@ -10,7 +10,7 @@ def check_moderation_endpoint(prompt: str):
|
|||
}
|
||||
response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=10)
|
||||
if response.status_code != 200:
|
||||
print(response.text)
|
||||
print('moderation failed:', response)
|
||||
response.raise_for_status()
|
||||
response = response.json()
|
||||
|
||||
|
|
|
@ -0,0 +1,97 @@
|
|||
from flask import jsonify
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def oai_to_vllm(request_json_body, stop_hashes: bool, mode):
|
||||
if not request_json_body.get('stop'):
|
||||
request_json_body['stop'] = []
|
||||
if not isinstance(request_json_body['stop'], list):
|
||||
# It is a string, so create a list with the existing element.
|
||||
request_json_body['stop'] = [request_json_body['stop']]
|
||||
|
||||
if stop_hashes:
|
||||
if opts.openai_force_no_hashes:
|
||||
request_json_body['stop'].append('###')
|
||||
else:
|
||||
# TODO: make stopping strings a configurable
|
||||
request_json_body['stop'].extend(['### INSTRUCTION', '### USER', '### ASSISTANT'])
|
||||
else:
|
||||
request_json_body['stop'].extend(['user:', 'assistant:'])
|
||||
|
||||
if request_json_body.get('frequency_penalty', 0) < -2:
|
||||
request_json_body['frequency_penalty'] = -2
|
||||
elif request_json_body.get('frequency_penalty', 0) > 2:
|
||||
request_json_body['frequency_penalty'] = 2
|
||||
|
||||
if mode == 'vllm' and request_json_body.get('top_p') == 0:
|
||||
request_json_body['top_p'] = 0.01
|
||||
|
||||
request_json_body['max_tokens'] = min(max(request_json_body.get('max_new_tokens', 0), request_json_body.get('max_tokens', 0)), opts.max_new_tokens)
|
||||
if request_json_body['max_tokens'] == 0:
|
||||
# We don't want to set any defaults here.
|
||||
del request_json_body['max_tokens']
|
||||
|
||||
return request_json_body
|
||||
|
||||
|
||||
def format_oai_err(err_msg):
|
||||
print('OAI ERROR MESSAGE:', err_msg)
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": err_msg,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
}), 400
|
||||
|
||||
|
||||
def validate_oai(parameters):
|
||||
if parameters.get('messages'):
|
||||
for m in parameters['messages']:
|
||||
if m['role'].lower() not in ['assistant', 'user', 'system']:
|
||||
return format_oai_err('messages role must be assistant, user, or system')
|
||||
|
||||
if parameters.get('temperature', 0) > 2:
|
||||
return format_oai_err(f"{parameters['temperature']} is greater than the maximum of 2 - 'temperature'")
|
||||
if parameters.get('temperature', 0) < 0:
|
||||
return format_oai_err(f"{parameters['temperature']} less than the minimum of 0 - 'temperature'")
|
||||
|
||||
if parameters.get('top_p', 1) > 2:
|
||||
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
|
||||
if parameters.get('top_p', 1) < 0:
|
||||
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
|
||||
|
||||
if parameters.get('presence_penalty', 1) > 2:
|
||||
return format_oai_err(f"{parameters['presence_penalty']} is greater than the maximum of 2 - 'presence_penalty'")
|
||||
if parameters.get('presence_penalty', 1) < -2:
|
||||
return format_oai_err(f"{parameters['presence_penalty']} less than the minimum of -2 - 'presence_penalty'")
|
||||
|
||||
if parameters.get('top_p', 1) > 2:
|
||||
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
|
||||
if parameters.get('top_p', 1) < 0:
|
||||
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
|
||||
|
||||
if parameters.get('top_p', 1) > 2:
|
||||
return format_oai_err(f"{parameters['top_p']} is greater than the maximum of 1 - 'top_p'")
|
||||
if parameters.get('top_p', 1) < 0:
|
||||
return format_oai_err(f"{parameters['top_p']} less than the minimum of 0 - 'top_p'")
|
||||
|
||||
if parameters.get('max_tokens', 2) < 1:
|
||||
return format_oai_err(f"{parameters['max_tokens']} is less than the minimum of 1 - 'max_tokens'")
|
||||
|
||||
|
||||
def return_invalid_model_err(requested_model: str):
|
||||
if requested_model:
|
||||
msg = f"The model `{requested_model}` does not exist"
|
||||
else:
|
||||
msg = "The requested model does not exist"
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": msg,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": "model_not_found"
|
||||
}
|
||||
}), 404
|
|
@ -2,86 +2,35 @@ import concurrent.futures
|
|||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, List
|
||||
|
||||
import tiktoken
|
||||
from flask import jsonify, make_response
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
|
||||
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
|
||||
|
||||
|
||||
def build_openai_response(prompt, response, model=None):
|
||||
# Seperate the user's prompt from the context
|
||||
x = prompt.split('### USER:')
|
||||
if len(x) > 1:
|
||||
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
|
||||
|
||||
# Make sure the bot doesn't put any other instructions in its response
|
||||
# y = response.split('\n### ')
|
||||
# if len(y) > 1:
|
||||
# response = re.sub(r'\n$', '', y[0].strip(' '))
|
||||
response = re.sub(ANTI_RESPONSE_RE, '', response)
|
||||
response = re.sub(ANTI_CONTINUATION_RE, '', response)
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt)
|
||||
response_tokens = llm_server.llm.get_token_count(response)
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
|
||||
stats = redis.get('proxy_stats', dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
|
||||
|
||||
def generate_oai_string(length=24):
|
||||
alphabet = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(alphabet) for i in range(length))
|
||||
|
||||
|
||||
def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -> List[Dict[str, str]]:
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
def get_token_count_tiktoken_thread(msg):
|
||||
return len(tokenizer.encode(msg["content"]))
|
||||
def trim_messages_to_fit(prompt: List[Dict[str, str]], context_token_limit: int, backend_url: str) -> List[Dict[str, str]]:
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg["content"], backend_url)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
token_counts = list(executor.map(get_token_count_tiktoken_thread, prompt))
|
||||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
|
||||
total_tokens = sum(token_counts)
|
||||
formatting_tokens = len(tokenizer.encode(transform_messages_to_prompt(prompt))) - total_tokens
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
|
||||
|
||||
# If total tokens exceed the limit, start trimming
|
||||
if total_tokens > context_token_limit:
|
||||
if total_tokens + formatting_tokens > context_token_limit:
|
||||
while True:
|
||||
while total_tokens + formatting_tokens > context_token_limit:
|
||||
# Calculate the index to start removing messages from
|
||||
|
@ -94,22 +43,43 @@ def trim_prompt_to_fit(prompt: List[Dict[str, str]], context_token_limit: int) -
|
|||
if total_tokens + formatting_tokens <= context_token_limit or remove_index == len(prompt):
|
||||
break
|
||||
|
||||
def get_token_count_thread(msg):
|
||||
return get_token_count(msg["content"])
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
|
||||
total_tokens = sum(token_counts)
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt)) - total_tokens
|
||||
|
||||
formatting_tokens = get_token_count(transform_messages_to_prompt(prompt), backend_url) - total_tokens
|
||||
if total_tokens + formatting_tokens > context_token_limit:
|
||||
# Start over, but this time calculate the token count using the backend
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
|
||||
token_counts = list(executor.map(get_token_count_thread, prompt))
|
||||
else:
|
||||
break
|
||||
return prompt
|
||||
|
||||
|
||||
def trim_string_to_fit(prompt: str, context_token_limit: int, backend_url: str) -> str:
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
token_count = get_token_count(prompt, backend_url)
|
||||
|
||||
# If total tokens exceed the limit, start trimming
|
||||
if token_count > context_token_limit:
|
||||
while True:
|
||||
while token_count > context_token_limit:
|
||||
# Calculate the index to start removing characters from
|
||||
remove_index = len(prompt) // 3
|
||||
|
||||
while remove_index < len(prompt):
|
||||
prompt = prompt[:remove_index] + prompt[remove_index + 100:]
|
||||
token_count = len(tokenizer.encode(prompt))
|
||||
if token_count <= context_token_limit or remove_index == len(prompt):
|
||||
break
|
||||
|
||||
token_count = get_token_count(prompt, backend_url)
|
||||
if token_count > context_token_limit:
|
||||
# Start over, but this time calculate the token count using the backend
|
||||
token_count = get_token_count(prompt, backend_url)
|
||||
else:
|
||||
break
|
||||
return prompt
|
||||
|
||||
|
||||
|
@ -117,8 +87,9 @@ def transform_messages_to_prompt(oai_messages):
|
|||
try:
|
||||
prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
|
||||
for msg in oai_messages:
|
||||
if not msg.get('content') or not msg.get('role'):
|
||||
if 'content' not in msg.keys() or 'role' not in msg.keys():
|
||||
return False
|
||||
msg['content'] = str(msg['content']) # Prevent any weird issues.
|
||||
if msg['role'] == 'system':
|
||||
prompt += f'### INSTRUCTION: {msg["content"]}\n\n'
|
||||
elif msg['role'] == 'user':
|
||||
|
@ -126,7 +97,7 @@ def transform_messages_to_prompt(oai_messages):
|
|||
elif msg['role'] == 'assistant':
|
||||
prompt += f'### ASSISTANT: {msg["content"]}\n\n'
|
||||
else:
|
||||
return False
|
||||
raise Exception(f'Unknown role: {msg["role"]}')
|
||||
except Exception as e:
|
||||
# TODO: use logging
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -1,80 +1,21 @@
|
|||
"""
|
||||
This file is used by the worker that processes requests.
|
||||
"""
|
||||
import json
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
||||
# TODO: make the VLMM backend return TPS and time elapsed
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/api_server.py
|
||||
|
||||
def prepare_json(json_data: dict):
|
||||
# logit_bias is not currently supported
|
||||
# del json_data['logit_bias']
|
||||
|
||||
# Convert back to VLLM.
|
||||
json_data['max_tokens'] = json_data.pop('max_new_tokens')
|
||||
return json_data
|
||||
|
||||
|
||||
def transform_to_text(json_request, api_response):
|
||||
"""
|
||||
This is to convert a streaming request to a non-streamed request. Don't think this is nessesary.
|
||||
:param json_request:
|
||||
:param api_response:
|
||||
:return:
|
||||
"""
|
||||
prompt = transform_prompt_to_text(json_request['messages'])
|
||||
text = ''
|
||||
finish_reason = None
|
||||
for line in api_response.split('\n'):
|
||||
if line.startswith('data:'):
|
||||
try:
|
||||
data = json.loads(line[5:].strip())
|
||||
except json.decoder.JSONDecodeError:
|
||||
break
|
||||
if 'choices' in data:
|
||||
for choice in data['choices']:
|
||||
if 'delta' in choice and 'content' in choice['delta']:
|
||||
text += choice['delta']['content']
|
||||
if data['choices'][0]['finish_reason']:
|
||||
finish_reason = data['choices'][0]['finish_reason']
|
||||
|
||||
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
|
||||
completion_tokens = len(llm_server.llm.get_token_count(text))
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
|
||||
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
|
||||
return {
|
||||
"id": str(uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def transform_prompt_to_text(prompt: list):
|
||||
text = ''
|
||||
for item in prompt:
|
||||
|
@ -82,26 +23,26 @@ def transform_prompt_to_text(prompt: list):
|
|||
return text.strip('\n')
|
||||
|
||||
|
||||
def handle_blocking_request(json_data: dict):
|
||||
def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
|
||||
try:
|
||||
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
|
||||
except requests.exceptions.ReadTimeout:
|
||||
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
|
||||
# print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
|
||||
return False, None, 'Request to backend timed out'
|
||||
except Exception as e:
|
||||
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
return False, None, 'Request to backend encountered error'
|
||||
if r.status_code != 200:
|
||||
print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
|
||||
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
|
||||
return False, r, f'Backend returned {r.status_code}'
|
||||
return True, r, None
|
||||
|
||||
|
||||
def generate(json_data: dict):
|
||||
def generate(json_data: dict, cluster_backend, timeout: int = None):
|
||||
if json_data.get('stream'):
|
||||
try:
|
||||
return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
|
||||
except Exception as e:
|
||||
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
return False
|
||||
else:
|
||||
return handle_blocking_request(json_data)
|
||||
return handle_blocking_request(json_data, cluster_backend, timeout=timeout)
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
|
||||
<strong>Supported Parameters:</strong>
|
||||
<ul>
|
||||
|
@ -7,4 +11,4 @@ vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="
|
|||
<li><kbd>max_new_tokens</kbd></li>
|
||||
<li><kbd>num_beams</kbd> <span style="font-size:9pt">(setting to greater than 1 enables beam search)</span></li>
|
||||
<li><kbd>ban_eos_token</kbd></li>
|
||||
</ul>"""
|
||||
</ul>"""
|
||||
|
|
|
@ -1,26 +1,51 @@
|
|||
import concurrent.futures
|
||||
|
||||
import requests
|
||||
import tiktoken
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.logging import create_logger
|
||||
|
||||
|
||||
def tokenize(prompt: str) -> int:
|
||||
def tokenize(prompt: str, backend_url: str) -> int:
|
||||
assert backend_url
|
||||
assert isinstance(backend_url, str)
|
||||
|
||||
if not prompt:
|
||||
# The tokenizers have issues when the prompt is None.
|
||||
return 0
|
||||
assert isinstance(prompt, str)
|
||||
|
||||
logger = create_logger('tokenizer')
|
||||
|
||||
# The backend could have died between when the request was
|
||||
# submitted and now, so let's double check it's still online.
|
||||
backend_url = cluster_config.validate_backend(backend_url)
|
||||
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
# First we tokenize it locally to determine if it's worth sending it to the backend.
|
||||
initial_estimate = len(tokenizer.encode(prompt))
|
||||
if initial_estimate <= opts.context_size + 200:
|
||||
# Split the prompt into 2000 character chunks
|
||||
chunk_size = 2000
|
||||
chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
|
||||
|
||||
# Define a function to send a chunk to the server
|
||||
def send_chunk(chunk):
|
||||
try:
|
||||
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
j = r.json()
|
||||
return j['length']
|
||||
except Exception as e:
|
||||
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
|
||||
return len(tokenizer.encode(prompt)) + 10
|
||||
else:
|
||||
# If the result was greater than our context size, return the estimate.
|
||||
# We won't be sending it through the backend so it does't need to be accurage.
|
||||
return initial_estimate
|
||||
logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
|
||||
return len(tokenizer.encode(chunk)) + 10
|
||||
|
||||
# Use a ThreadPoolExecutor to send all chunks to the server at once
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future_to_chunk = {executor.submit(send_chunk, chunk): chunk for chunk in chunks}
|
||||
for future in concurrent.futures.as_completed(future_to_chunk):
|
||||
chunk = future_to_chunk[future]
|
||||
try:
|
||||
data = future.result()
|
||||
except Exception as exc:
|
||||
logger.warning('%r generated an exception: %s' % (chunk, exc))
|
||||
return sum(future.result() for future in future_to_chunk)
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import threading
|
||||
from typing import Tuple
|
||||
|
||||
from flask import jsonify
|
||||
from vllm import SamplingParams
|
||||
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.llm.llm_backend import LLMBackend
|
||||
|
||||
|
||||
|
@ -19,16 +18,8 @@ class VLLMBackend(LLMBackend):
|
|||
# Failsafe
|
||||
backend_response = ''
|
||||
|
||||
r_url = request.url
|
||||
|
||||
def background_task():
|
||||
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task)
|
||||
thread.start()
|
||||
thread.join()
|
||||
log_to_db(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
|
||||
|
||||
return jsonify({'results': [{'text': backend_response}]}), 200
|
||||
|
||||
|
@ -38,14 +29,20 @@ class VLLMBackend(LLMBackend):
|
|||
top_k = parameters.get('top_k', self._default_params['top_k'])
|
||||
if top_k <= 0:
|
||||
top_k = -1
|
||||
|
||||
# TODO: support more params
|
||||
sampling_params = SamplingParams(
|
||||
temperature=parameters.get('temperature', self._default_params['temperature']),
|
||||
top_p=parameters.get('top_p', self._default_params['top_p']),
|
||||
top_k=top_k,
|
||||
use_beam_search=True if parameters.get('num_beams', 0) > 1 else False,
|
||||
stop=parameters.get('stopping_strings', self._default_params['stop']),
|
||||
stop=list(set(parameters.get('stopping_strings') or parameters.get('stop', self._default_params['stop']))),
|
||||
ignore_eos=parameters.get('ban_eos_token', False),
|
||||
max_tokens=parameters.get('max_new_tokens', self._default_params['max_tokens'])
|
||||
max_tokens=parameters.get('max_new_tokens') or parameters.get('max_tokens', self._default_params['max_tokens']),
|
||||
presence_penalty=parameters.get('presence_penalty', self._default_params['presence_penalty']),
|
||||
frequency_penalty=parameters.get('frequency_penalty', self._default_params['frequency_penalty']),
|
||||
length_penalty=parameters.get('length_penalty', self._default_params['length_penalty']),
|
||||
early_stopping=parameters.get('early_stopping', self._default_params['early_stopping'])
|
||||
)
|
||||
except ValueError as e:
|
||||
return None, str(e).strip('.')
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
import logging
|
||||
|
||||
import coloredlogs
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
class LoggingInfo:
|
||||
def __init__(self):
|
||||
self._level = logging.INFO
|
||||
self._format = opts.LOGGING_FORMAT
|
||||
|
||||
@property
|
||||
def level(self):
|
||||
return self._level
|
||||
|
||||
@level.setter
|
||||
def level(self, value):
|
||||
self._level = value
|
||||
|
||||
@property
|
||||
def format(self):
|
||||
return self._format
|
||||
|
||||
@format.setter
|
||||
def format(self, value):
|
||||
self._format = value
|
||||
|
||||
|
||||
logging_info = LoggingInfo()
|
||||
|
||||
|
||||
def init_logging():
|
||||
"""
|
||||
Set up the parent logger.
|
||||
:return:
|
||||
"""
|
||||
logger = logging.getLogger('llm_server')
|
||||
logger.setLevel(logging_info.level)
|
||||
|
||||
|
||||
def create_logger(name):
|
||||
logger = logging.getLogger('llm_server').getChild(name)
|
||||
logger.setLevel(logging_info.level)
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging_info.level)
|
||||
formatter = logging.Formatter(logging_info.format)
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
coloredlogs.install(logger=logger, level=logging_info.level)
|
||||
return logger
|
|
@ -0,0 +1 @@
|
|||
BACKEND_OFFLINE = 'The model you requested is not a valid choice. Please retry your query.'
|
|
@ -1,52 +0,0 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def get_power_states():
|
||||
gpu_num = 0
|
||||
output = {}
|
||||
while True:
|
||||
url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
if response.status_code != 200:
|
||||
break
|
||||
data = json.loads(response.text)
|
||||
power_state_data = data['data'][0]
|
||||
power_state = None
|
||||
for i in range(1, len(power_state_data)):
|
||||
if power_state_data[i] == 1:
|
||||
power_state = data['labels'][i]
|
||||
break
|
||||
output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
|
||||
except Exception as e:
|
||||
print('Failed to fetch Netdata metrics:', e)
|
||||
return output
|
||||
gpu_num += 1
|
||||
return output
|
||||
|
||||
|
||||
def get_gpu_wh(gpu_id: int):
|
||||
chart_name = f"nvidia_smi.gpu{gpu_id}_power"
|
||||
now = datetime.now()
|
||||
one_hour_ago = now - timedelta(hours=1)
|
||||
num_seconds = int((now - one_hour_ago).total_seconds())
|
||||
params = {
|
||||
"chart": chart_name,
|
||||
"after": int(one_hour_ago.timestamp()),
|
||||
"before": int(now.timestamp()),
|
||||
"points": num_seconds,
|
||||
"group": "second",
|
||||
"format": "json",
|
||||
"options": "absolute|jsonwrap"
|
||||
}
|
||||
response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
|
||||
data = json.loads(response.text)
|
||||
total_power_usage_watts = sum(point[1] for point in data['result']['data'])
|
||||
# total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
|
||||
total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
|
||||
return total_power_usage_kwh
|
|
@ -1,12 +1,11 @@
|
|||
# Read-only global variables
|
||||
|
||||
# Uppercase variables are read-only globals.
|
||||
# Lowercase variables are ones that are set on startup and are never changed.
|
||||
|
||||
# TODO: rewrite the config system so I don't have to add every single config default here
|
||||
|
||||
running_model = 'ERROR'
|
||||
concurrent_gens = 3
|
||||
mode = 'oobabooga'
|
||||
backend_url = None
|
||||
context_size = 5555
|
||||
frontend_api_mode = 'ooba'
|
||||
max_new_tokens = 500
|
||||
auth_required = False
|
||||
log_prompts = False
|
||||
|
@ -33,7 +32,15 @@ openai_expose_our_model = False
|
|||
openai_force_no_hashes = True
|
||||
include_system_tokens_in_stats = True
|
||||
openai_moderation_scan_last_n = 5
|
||||
openai_moderation_workers = 10
|
||||
openai_org_name = 'OpenAI'
|
||||
openai_silent_trim = False
|
||||
openai_moderation_enabled = True
|
||||
cluster = {}
|
||||
show_backends = True
|
||||
background_homepage_cacher = True
|
||||
openai_moderation_timeout = 5
|
||||
prioritize_by_size = False
|
||||
cluster_workers = 0
|
||||
redis_stream_timeout = 25000
|
||||
|
||||
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
|
||||
|
|
|
@ -1,21 +1,9 @@
|
|||
import sys
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def server_startup(s):
|
||||
if not redis.get('daemon_started', bool):
|
||||
if not redis.get('daemon_started', dtype=bool):
|
||||
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
|
||||
sys.exit(1)
|
||||
|
||||
# Flush the RedisPriorityQueue database.
|
||||
queue_redis = Redis(host='localhost', port=6379, db=15)
|
||||
for key in queue_redis.scan_iter('*'):
|
||||
queue_redis.delete(key)
|
||||
|
||||
# Cache the initial stats
|
||||
print('Loading backend stats...')
|
||||
generate_stats()
|
||||
|
|
|
@ -1,11 +1,18 @@
|
|||
from llm_server import opts
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def format_sillytavern_err(msg: str, level: str = 'info'):
|
||||
http_host = redis.get('http_host', str)
|
||||
def format_sillytavern_err(msg: str, backend_url: str = None, error_type: str = 'info'):
|
||||
if backend_url:
|
||||
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
|
||||
else:
|
||||
cluster_backend_hash = 'none'
|
||||
http_host = redis.get('http_host', dtype=str)
|
||||
return f"""```
|
||||
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
|
||||
-> {level.upper()} <-
|
||||
-> {error_type.upper()} <-
|
||||
{msg}
|
||||
```
|
||||
```
|
||||
BACKEND: {cluster_backend_hash}
|
||||
```"""
|
||||
|
|
|
@ -100,4 +100,4 @@ def validate_json(data: Union[str, flask.Request, requests.models.Response, flas
|
|||
j = json.loads(str(data))
|
||||
return True, j
|
||||
except Exception as e:
|
||||
return False, e
|
||||
return False, e
|
|
@ -0,0 +1,15 @@
|
|||
def estimate_model_size(config: dict):
|
||||
"""
|
||||
Estimate the size of a model from its config. No idea if this is correct,
|
||||
but it allows us to compare models.
|
||||
:param config:
|
||||
:return:
|
||||
"""
|
||||
vocab_size = config.get('vocab_size')
|
||||
hidden_size = config.get('hidden_size')
|
||||
num_hidden_layers = config.get('num_hidden_layers')
|
||||
intermediate_size = config.get('intermediate_size')
|
||||
if vocab_size and hidden_size and num_hidden_layers and intermediate_size:
|
||||
total_params = (vocab_size * hidden_size) + (num_hidden_layers * ((hidden_size * intermediate_size * 4) + (hidden_size * hidden_size * 3)))
|
||||
return int(total_params / 1e9)
|
||||
return 0
|
|
@ -3,8 +3,8 @@ from typing import Tuple
|
|||
import flask
|
||||
from flask import jsonify, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server import messages, opts
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.routes.helpers.client import format_sillytavern_err
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
|
||||
|
@ -13,8 +13,11 @@ class OobaRequestHandler(RequestHandler):
|
|||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def handle_request(self):
|
||||
def handle_request(self, return_ok: bool = True):
|
||||
assert not self.used
|
||||
if self.offline:
|
||||
print('This backend is offline:', messages.BACKEND_OFFLINE)
|
||||
return self.handle_error(messages.BACKEND_OFFLINE)
|
||||
|
||||
request_valid, invalid_response = self.validate_request()
|
||||
if not request_valid:
|
||||
|
@ -25,14 +28,19 @@ class OobaRequestHandler(RequestHandler):
|
|||
llm_request = {**self.parameters, 'prompt': prompt}
|
||||
|
||||
_, backend_response = self.generate_response(llm_request)
|
||||
return backend_response
|
||||
if return_ok:
|
||||
# Always return 200 so ST displays our error messages
|
||||
return backend_response[0], 200
|
||||
else:
|
||||
# The OpenAI route needs to detect 429 errors.
|
||||
return backend_response
|
||||
|
||||
def handle_ratelimited(self, do_log: bool = True):
|
||||
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
backend_response = self.handle_error(msg)
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
|
||||
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
return backend_response[0], 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
|
||||
|
@ -40,7 +48,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
# TODO: how to format this
|
||||
response_msg = error_msg
|
||||
else:
|
||||
response_msg = format_sillytavern_err(error_msg, error_type)
|
||||
response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
|
||||
|
||||
return jsonify({
|
||||
'results': [{'text': response_msg}]
|
||||
|
|
|
@ -5,9 +5,11 @@ from ..server_error import handle_server_error
|
|||
from ... import opts
|
||||
|
||||
openai_bp = Blueprint('openai/v1/', __name__)
|
||||
openai_model_bp = Blueprint('openai/', __name__)
|
||||
|
||||
|
||||
@openai_bp.before_request
|
||||
@openai_model_bp.before_request
|
||||
def before_oai_request():
|
||||
if not opts.enable_openi_compatible_backend:
|
||||
return 'The OpenAI-compatible backend is disabled.', 401
|
||||
|
@ -15,8 +17,22 @@ def before_oai_request():
|
|||
|
||||
|
||||
@openai_bp.errorhandler(500)
|
||||
@openai_model_bp.errorhandler(500)
|
||||
def handle_error(e):
|
||||
return handle_server_error(e)
|
||||
"""
|
||||
Found Codes:
|
||||
"auth_subrequest_error"
|
||||
"""
|
||||
|
||||
print('OAI returning error:', e)
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": "Internal server error",
|
||||
"type": "auth_subrequest_error",
|
||||
"param": None,
|
||||
"code": "internal_error"
|
||||
}
|
||||
}), 500
|
||||
|
||||
|
||||
from .models import openai_list_models
|
||||
|
|
|
@ -1,113 +1,175 @@
|
|||
import json
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import ujson
|
||||
from flask import Response, jsonify, request
|
||||
from redis import Redis
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import redis
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp, openai_model_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
|
||||
from ...llm.vllm import tokenize
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
|
||||
|
||||
# TODO: add rate-limit headers?
|
||||
|
||||
|
||||
@openai_bp.route('/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions():
|
||||
@openai_model_bp.route('/<model_name>/v1/chat/completions', methods=['POST'])
|
||||
def openai_chat_completions(model_name=None):
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('messages') or not request_json_body.get('model'):
|
||||
return jsonify({'code': 400, 'msg': 'invalid JSON'}), 400
|
||||
else:
|
||||
handler = OpenAIRequestHandler(request, request_json_body)
|
||||
if request_json_body.get('stream'):
|
||||
if not opts.enable_streaming:
|
||||
# TODO: return a proper OAI error message
|
||||
return 'disabled', 401
|
||||
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||
if handler.offline:
|
||||
return return_invalid_model_err(model_name)
|
||||
|
||||
if opts.mode != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception('TODO: simulate OAI here')
|
||||
else:
|
||||
handler.prompt = transform_messages_to_prompt(request_json_body['messages'])
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
try:
|
||||
response = generator(msg_to_backend)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
data = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
def background_task():
|
||||
generated_tokens = tokenize(generated_text)
|
||||
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except:
|
||||
# TODO: simulate OAI here
|
||||
raise Exception
|
||||
else:
|
||||
if not request_json_body.get('stream'):
|
||||
try:
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'Internal server error', 500
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
handler.parameters, e = handler.get_parameters()
|
||||
handler.request_json_body = {
|
||||
'messages': handler.request_json_body['messages'],
|
||||
'model': handler.request_json_body['model'],
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
|
||||
else:
|
||||
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
|
||||
if not handler.prompt:
|
||||
# Prevent issues on the backend.
|
||||
return 'Invalid prompt', 400
|
||||
|
||||
# Need to set the prompt in the JSON body since that's what the inference worker expects.
|
||||
handler.request_json_body['prompt'] = handler.prompt
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
429,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
# Need to do this before we enter generate() since we want to be able to
|
||||
# return a 408 if necessary.
|
||||
_, stream_name, error_msg = event.wait()
|
||||
if error_msg:
|
||||
print('OAI failed to start streaming:', error_msg)
|
||||
stream_name = None # set to null so that the Finally ignores it.
|
||||
return 'Request Timeout', 408
|
||||
|
||||
def generate():
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
try:
|
||||
last_id = '0-0'
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
|
||||
if not stream_data:
|
||||
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
|
||||
yield 'data: [DONE]\n\n'
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
timestamp = int(stream_index.decode('utf-8').split('-')[0])
|
||||
data = ujson.loads(item[b'data'])
|
||||
if data['error']:
|
||||
# Not printing error since we can just check the daemon log.
|
||||
print('OAI streaming encountered error')
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
elif data['new']:
|
||||
response = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": timestamp,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": data['new']
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
generated_text = generated_text + data['new']
|
||||
yield f'data: {json.dumps(response)}\n\n'
|
||||
elif data['completed']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
200,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return
|
||||
except GeneratorExit:
|
||||
return
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
yield 'data: [DONE]\n\n'
|
||||
finally:
|
||||
if event:
|
||||
redis.publish(f'notifications:{event.event_id}', 'canceled')
|
||||
if stream_name:
|
||||
stream_redis.delete(stream_name)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
|
|
|
@ -1,38 +1,68 @@
|
|||
import time
|
||||
import traceback
|
||||
|
||||
from flask import jsonify, make_response, request
|
||||
import simplejson as json
|
||||
import ujson
|
||||
from flask import Response, jsonify, request
|
||||
from redis import Redis
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import redis
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp, openai_model_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm import get_token_count
|
||||
from ...llm.openai.transform import build_openai_response, generate_oai_string
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
||||
|
||||
|
||||
# TODO: add rate-limit headers?
|
||||
|
||||
@openai_bp.route('/completions', methods=['POST'])
|
||||
def openai_completions():
|
||||
@openai_model_bp.route('/<model_name>/v1/completions', methods=['POST'])
|
||||
def openai_completions(model_name=None):
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
try:
|
||||
response, status_code = OobaRequestHandler(request).handle_request()
|
||||
if status_code != 200:
|
||||
return status_code
|
||||
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
|
||||
if handler.offline:
|
||||
return return_invalid_model_err(model_name)
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
# The handle_request() call below will load the prompt so we don't have
|
||||
# to do anything else here.
|
||||
pass
|
||||
|
||||
handler.request_json_body['prompt'] = handler.prompt
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
invalid_oai_err_msg = validate_oai(request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
response, status_code = handler.handle_request(return_ok=False)
|
||||
if status_code == 429:
|
||||
return handler.handle_ratelimited()
|
||||
output = response.json['results'][0]['text']
|
||||
|
||||
# TODO: async/await
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'])
|
||||
response_tokens = get_token_count(output)
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
|
||||
response_tokens = get_token_count(output, handler.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
response = jsonify({
|
||||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
|
@ -42,7 +72,7 @@ def openai_completions():
|
|||
"text": output,
|
||||
"index": 0,
|
||||
"logprobs": None,
|
||||
"finish_reason": None
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
|
@ -50,12 +80,141 @@ def openai_completions():
|
|||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
})
|
||||
|
||||
stats = redis.get('proxy_stats', dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'Internal Server Error', 500
|
||||
# TODO:
|
||||
# stats = redis.get('proxy_stats', dtype=dict)
|
||||
# if stats:
|
||||
# response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response, 200
|
||||
else:
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
handler.parameters, _ = handler.get_parameters()
|
||||
handler.request_json_body = {
|
||||
'prompt': handler.request_json_body['prompt'],
|
||||
'model': handler.request_json_body['model'],
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
|
||||
if not handler.prompt:
|
||||
# Prevent issues on the backend.
|
||||
return 'Invalid prompt', 400
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
429,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
_, stream_name, error_msg = event.wait()
|
||||
if error_msg:
|
||||
print('OAI failed to start streaming:', error_msg)
|
||||
stream_name = None
|
||||
return 'Request Timeout', 408
|
||||
|
||||
def generate():
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
try:
|
||||
last_id = '0-0'
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
|
||||
if not stream_data:
|
||||
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
|
||||
yield 'data: [DONE]\n\n'
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
timestamp = int(stream_index.decode('utf-8').split('-')[0])
|
||||
data = ujson.loads(item[b'data'])
|
||||
if data['error']:
|
||||
print('OAI streaming encountered error')
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
elif data['new']:
|
||||
response = {
|
||||
"id": f"cmpl-{oai_string}",
|
||||
"object": "text_completion",
|
||||
"created": timestamp,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": data['new']
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
generated_text = generated_text + data['new']
|
||||
yield f'data: {json.dumps(response)}\n\n'
|
||||
elif data['completed']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
200,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return
|
||||
except GeneratorExit:
|
||||
# This should be triggered if a client disconnects early.
|
||||
return
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
yield 'data: [DONE]\n\n'
|
||||
finally:
|
||||
if event:
|
||||
redis.publish(f'notifications:{event.event_id}', 'canceled')
|
||||
if stream_name:
|
||||
stream_redis.delete(stream_name)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from flask import Response
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import flask_cache
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from ... import opts
|
||||
|
||||
|
||||
|
|
|
@ -3,59 +3,58 @@ import traceback
|
|||
import requests
|
||||
from flask import jsonify
|
||||
|
||||
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
|
||||
from . import openai_bp
|
||||
from ..cache import ONE_MONTH_SECONDS, flask_cache, redis
|
||||
from ..stats import server_start_time
|
||||
from ... import opts
|
||||
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
|
||||
from ...helpers import jsonify_pretty
|
||||
from ...llm.info import get_running_model
|
||||
from ...llm.openai.transform import generate_oai_string
|
||||
|
||||
|
||||
@openai_bp.route('/models', methods=['GET'])
|
||||
@flask_cache.cached(timeout=60, query_string=True)
|
||||
def openai_list_models():
|
||||
model, error = get_running_model()
|
||||
if not model:
|
||||
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
|
||||
if not model_name:
|
||||
response = jsonify({
|
||||
'code': 502,
|
||||
'msg': 'failed to reach backend',
|
||||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
else:
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
oai = fetch_openai_models()
|
||||
r = []
|
||||
r = {
|
||||
"object": "list",
|
||||
"data": oai
|
||||
}
|
||||
# TODO: verify this works
|
||||
if opts.openai_expose_our_model:
|
||||
r = [{
|
||||
"object": "list",
|
||||
"data": [
|
||||
r["data"].insert(0, {
|
||||
"id": running_model,
|
||||
"object": "model",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"owned_by": opts.llm_middleware_name,
|
||||
"permission": [
|
||||
{
|
||||
"id": running_model,
|
||||
"object": "model",
|
||||
"object": "model_permission",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"owned_by": opts.llm_middleware_name,
|
||||
"permission": [
|
||||
{
|
||||
"id": running_model,
|
||||
"object": "model_permission",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"allow_create_engine": False,
|
||||
"allow_sampling": False,
|
||||
"allow_logprobs": False,
|
||||
"allow_search_indices": False,
|
||||
"allow_view": True,
|
||||
"allow_fine_tuning": False,
|
||||
"organization": "*",
|
||||
"group": None,
|
||||
"is_blocking": False
|
||||
}
|
||||
],
|
||||
"root": None,
|
||||
"parent": None
|
||||
"allow_create_engine": False,
|
||||
"allow_sampling": False,
|
||||
"allow_logprobs": False,
|
||||
"allow_search_indices": False,
|
||||
"allow_view": True,
|
||||
"allow_fine_tuning": False,
|
||||
"organization": "*",
|
||||
"group": None,
|
||||
"is_blocking": False
|
||||
}
|
||||
]
|
||||
}]
|
||||
response = jsonify_pretty(r + oai), 200
|
||||
],
|
||||
"root": None,
|
||||
"parent": None
|
||||
})
|
||||
response = jsonify_pretty(r), 200
|
||||
return response
|
||||
|
||||
|
||||
|
@ -64,7 +63,14 @@ def fetch_openai_models():
|
|||
if opts.openai_api_key:
|
||||
try:
|
||||
response = requests.get('https://api.openai.com/v1/models', headers={'Authorization': f"Bearer {opts.openai_api_key}"}, timeout=10)
|
||||
return response.json()['data']
|
||||
j = response.json()['data']
|
||||
|
||||
# The "modelperm" string appears to be user-specific, so we'll
|
||||
# randomize it just to be safe.
|
||||
for model in range(len(j)):
|
||||
for p in range(len(j[model]['permission'])):
|
||||
j[model]['permission'][p]['id'] = f'modelperm-{generate_oai_string(24)}'
|
||||
return j
|
||||
except:
|
||||
traceback.print_exc()
|
||||
return []
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from flask import jsonify
|
||||
|
||||
from . import openai_bp
|
||||
from ..cache import ONE_MONTH_SECONDS, flask_cache
|
||||
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache
|
||||
from ...llm.openai.transform import generate_oai_string
|
||||
from ..stats import server_start_time
|
||||
|
||||
|
@ -17,7 +17,7 @@ def openai_organizations():
|
|||
"id": f"org-{generate_oai_string(24)}",
|
||||
"created": int(server_start_time.timestamp()),
|
||||
"title": "Personal",
|
||||
"name": "user-abcdefghijklmnopqrstuvwx",
|
||||
"name": f"user-{generate_oai_string(24)}",
|
||||
"description": "Personal org for bobjoe@0.0.0.0",
|
||||
"personal": True,
|
||||
"is_default": True,
|
||||
|
|
|
@ -1,14 +1,21 @@
|
|||
import json
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
from typing import Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import flask
|
||||
from flask import jsonify
|
||||
from flask import Response, jsonify, make_response
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import is_api_key_moderated
|
||||
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
|
||||
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
|
||||
from llm_server.routes.request_handler import RequestHandler
|
||||
from llm_server.workers.moderator import add_moderation_task, get_results
|
||||
|
||||
|
@ -20,20 +27,37 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
|
||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||
assert not self.used
|
||||
if self.offline:
|
||||
msg = return_invalid_model_err(self.selected_model)
|
||||
print('OAI Offline:', msg)
|
||||
return self.handle_error(msg)
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
|
||||
oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
|
||||
else:
|
||||
oai_messages = self.request.json['messages']
|
||||
|
||||
self.prompt = transform_messages_to_prompt(oai_messages)
|
||||
self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
request_valid, invalid_response = self.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
if opts.openai_api_key and is_api_key_moderated(self.token):
|
||||
if not self.prompt:
|
||||
# TODO: format this as an openai error message
|
||||
return Response('Invalid prompt'), 400
|
||||
|
||||
# TODO: support Ooba backend
|
||||
self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
|
||||
invalid_oai_err_msg = validate_oai(self.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
|
||||
try:
|
||||
# Gather the last message from the user and all preceeding system messages
|
||||
# Gather the last message from the user and all preceding system messages
|
||||
msg_l = self.request.json['messages'].copy()
|
||||
msg_l.reverse()
|
||||
tag = uuid4()
|
||||
|
@ -49,33 +73,40 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
self.prompt = transform_messages_to_prompt(self.request.json['messages'])
|
||||
except Exception as e:
|
||||
print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
|
||||
print(traceback.format_exc())
|
||||
|
||||
# Reconstruct the request JSON with the validated parameters and prompt.
|
||||
self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
|
||||
if opts.openai_force_no_hashes:
|
||||
self.parameters['stop'].append('### ')
|
||||
|
||||
if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
|
||||
self.request_json_body['top_p'] = 0.01
|
||||
traceback.print_exc()
|
||||
|
||||
llm_request = {**self.parameters, 'prompt': self.prompt}
|
||||
(success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
|
||||
|
||||
model = self.request_json_body.get('model')
|
||||
|
||||
if success:
|
||||
return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
|
||||
return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
|
||||
else:
|
||||
return backend_response, backend_response_status_code
|
||||
|
||||
def handle_ratelimited(self, do_log: bool = True):
|
||||
# TODO: return a simulated OpenAI error message
|
||||
# Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
|
||||
return 'Ratelimited', 429
|
||||
model_choices, default_model = get_model_choices()
|
||||
default_model_info = model_choices[default_model]
|
||||
w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2
|
||||
response = jsonify({
|
||||
"error": {
|
||||
"message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
|
||||
"type": "rate_limit_exceeded",
|
||||
"param": None,
|
||||
"code": None
|
||||
}
|
||||
})
|
||||
response.headers['x-ratelimit-limit-requests'] = '2'
|
||||
response.headers['x-ratelimit-remaining-requests'] = '0'
|
||||
response.headers['x-ratelimit-reset-requests'] = f"{w}s"
|
||||
|
||||
if do_log:
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
|
||||
return response, 429
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
# TODO: return a simulated OpenAI error message
|
||||
print('OAI Error:', error_msg)
|
||||
return jsonify({
|
||||
"error": {
|
||||
"message": "Invalid request, check your parameters and try again.",
|
||||
|
@ -84,3 +115,51 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
"code": None
|
||||
}
|
||||
}), 400
|
||||
|
||||
def build_openai_response(self, prompt, response, model=None):
|
||||
# Seperate the user's prompt from the context
|
||||
x = prompt.split('### USER:')
|
||||
if len(x) > 1:
|
||||
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
|
||||
|
||||
# Make sure the bot doesn't put any other instructions in its response
|
||||
response = re.sub(ANTI_RESPONSE_RE, '', response)
|
||||
response = re.sub(ANTI_CONTINUATION_RE, '', response)
|
||||
|
||||
prompt_tokens = get_token_count(prompt, self.backend_url)
|
||||
response_tokens = get_token_count(response, self.backend_url)
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model if opts.openai_expose_our_model else model,
|
||||
"choices": [{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": response,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop"
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": response_tokens,
|
||||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
return response
|
||||
|
||||
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
|
||||
self.parameters, parameters_invalid_msg = self.get_parameters()
|
||||
if not self.parameters:
|
||||
print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
|
||||
return False, (Response('Invalid request, check your parameters and try again.'), 400)
|
||||
invalid_oai_err_msg = validate_oai(self.parameters)
|
||||
if invalid_oai_err_msg:
|
||||
return False, invalid_oai_err_msg
|
||||
# self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])
|
||||
# If the parameters were invalid, let the superclass deal with it.
|
||||
return super().validate_request(prompt, do_log)
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
import json
|
||||
import pickle
|
||||
import time
|
||||
from typing import Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
import ujson as json
|
||||
from redis import Redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
from llm_server.database.database import get_token_ratelimit
|
||||
|
||||
|
||||
def increment_ip_count(client_ip: str, redis_key):
|
||||
|
@ -20,24 +23,30 @@ def decrement_ip_count(client_ip: str, redis_key):
|
|||
|
||||
|
||||
class RedisPriorityQueue:
|
||||
def __init__(self):
|
||||
self.redis = Redis(host='localhost', port=6379, db=15)
|
||||
self.pubsub = self.redis.pubsub()
|
||||
self.pubsub.subscribe('events')
|
||||
"""
|
||||
A queue for a specific backend.
|
||||
"""
|
||||
|
||||
def put(self, item, priority):
|
||||
event = DataEvent()
|
||||
def __init__(self, name, db: int = 12):
|
||||
self.name = name
|
||||
self.redis = RedisCustom(name, db=db)
|
||||
|
||||
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
|
||||
# TODO: remove this when we're sure nothing strange is happening
|
||||
assert item is not None
|
||||
assert priority is not None
|
||||
assert selected_model is not None
|
||||
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
ip_count = self.redis.hget('queued_ip_count', item[1])
|
||||
if ip_count:
|
||||
ip_count = int(ip_count)
|
||||
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
|
||||
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
|
||||
ip_count = self.get_ip_request_count(item[1])
|
||||
_, simultaneous_ip = get_token_ratelimit(item[2])
|
||||
if ip_count and int(ip_count) >= simultaneous_ip and priority != 0:
|
||||
print(f'Rejecting request from {item[1]} - {ip_count} request queued.')
|
||||
return None # reject the request
|
||||
|
||||
self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
|
||||
self.increment_ip_count(item[1], 'queued_ip_count')
|
||||
timestamp = time.time()
|
||||
event = DataEvent()
|
||||
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
|
||||
return event
|
||||
|
||||
def get(self):
|
||||
|
@ -45,31 +54,59 @@ class RedisPriorityQueue:
|
|||
data = self.redis.zpopmin('queue')
|
||||
if data:
|
||||
item = json.loads(data[0][0])
|
||||
client_ip = item[0][1]
|
||||
self.decrement_ip_count(client_ip, 'queued_ip_count')
|
||||
return item
|
||||
time.sleep(0.1) # wait for something to be added to the queue
|
||||
|
||||
def increment_ip_count(self, client_ip: str, redis_key):
|
||||
self.redis.hincrby(redis_key, client_ip, 1)
|
||||
|
||||
def decrement_ip_count(self, client_ip: str, redis_key):
|
||||
new_count = self.redis.hincrby(redis_key, client_ip, -1)
|
||||
if new_count <= 0:
|
||||
self.redis.hdel(redis_key, client_ip)
|
||||
|
||||
def __len__(self):
|
||||
return self.redis.zcard('queue')
|
||||
|
||||
def get_queued_ip_count(self, client_ip: str):
|
||||
q = self.redis.hget('queued_ip_count', client_ip)
|
||||
if not q:
|
||||
return 0
|
||||
return 0
|
||||
def get_ip_request_count(self, client_ip: str):
|
||||
"""
|
||||
Get the number of requests in the queue from a specific IP.
|
||||
This is a bit inefficient since we iterate over the entire queue, but
|
||||
keeps the queue as a single point of truth instead of tracking a separate hashed
|
||||
set which can get confusing.
|
||||
If we run into slowdowns in the future, we should go back to the hashed set approach.
|
||||
:param client_ip:
|
||||
:return:
|
||||
"""
|
||||
start_time = time.time()
|
||||
items = self.redis.zrange('queue', 0, -1)
|
||||
count = 0
|
||||
for item in items:
|
||||
item_data = json.loads(item)
|
||||
if item_data[0][1] == client_ip:
|
||||
count += 1
|
||||
elapsed_time = time.time() - start_time
|
||||
if elapsed_time > 0.5:
|
||||
raise Exception(f"!!! get_ip_request_count took {elapsed_time} seconds to execute !!!")
|
||||
return count
|
||||
|
||||
def flush(self):
|
||||
self.redis.flush()
|
||||
|
||||
def items(self):
|
||||
return self.redis.zrange('queue', 0, -1)
|
||||
|
||||
def cleanup(self):
|
||||
now = time.time()
|
||||
for item in self.items():
|
||||
item_data = json.loads(item)
|
||||
timestamp = item_data[-2]
|
||||
if now - timestamp > opts.backend_generate_request_timeout:
|
||||
self.redis.zrem('queue', 0, item)
|
||||
event_id = item_data[1]
|
||||
event = DataEvent(event_id)
|
||||
event.set((False, None, 'closed'))
|
||||
print('Removed timed-out item from queue:', event_id)
|
||||
|
||||
|
||||
class DataEvent:
|
||||
def __init__(self, event_id=None):
|
||||
"""
|
||||
Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol).
|
||||
"""
|
||||
|
||||
def __init__(self, event_id: str = None):
|
||||
self.event_id = event_id if event_id else str(uuid4())
|
||||
self.redis = Redis(host='localhost', port=6379, db=14)
|
||||
self.pubsub = self.redis.pubsub()
|
||||
|
@ -84,15 +121,89 @@ class DataEvent:
|
|||
return pickle.loads(item['data'])
|
||||
|
||||
|
||||
priority_queue = RedisPriorityQueue()
|
||||
def update_active_workers(key: str, operation: str):
|
||||
if operation == 'incr':
|
||||
redis.incr(f'active_gen_workers:{key}')
|
||||
elif operation == 'decr':
|
||||
redis.decr(f'active_gen_workers:{key}')
|
||||
if redis.get(f'active_gen_workers:{key}', default=0, dtype=int) < 0:
|
||||
redis.set(f'active_gen_workers:{key}', 0)
|
||||
|
||||
|
||||
def incr_active_workers():
|
||||
redis.incr('active_gen_workers')
|
||||
def incr_active_workers(selected_model: str, backend_url: str):
|
||||
update_active_workers(selected_model, 'incr')
|
||||
update_active_workers(backend_url, 'incr')
|
||||
|
||||
|
||||
def decr_active_workers():
|
||||
redis.decr('active_gen_workers')
|
||||
new_count = redis.get('active_gen_workers', int, 0)
|
||||
if new_count < 0:
|
||||
redis.set('active_gen_workers', 0)
|
||||
def decr_active_workers(selected_model: str, backend_url: str):
|
||||
update_active_workers(selected_model, 'decr')
|
||||
update_active_workers(backend_url, 'decr')
|
||||
|
||||
|
||||
class PriorityQueue:
|
||||
"""
|
||||
Helper class to wrangler all the different queues.
|
||||
"""
|
||||
|
||||
def __init__(self, backends: set = None):
|
||||
"""
|
||||
Only have to load the backends once.
|
||||
:param backends:
|
||||
"""
|
||||
self.redis = Redis(host='localhost', port=6379, db=9)
|
||||
if backends:
|
||||
for item in backends:
|
||||
self.redis.sadd('backends', item)
|
||||
|
||||
def get_backends(self):
|
||||
return {x.decode('utf-8') for x in self.redis.smembers('backends')}
|
||||
|
||||
def get_queued_ip_count(self, client_ip: str):
|
||||
count = 0
|
||||
for backend_url in self.get_backends():
|
||||
queue = RedisPriorityQueue(backend_url)
|
||||
count += queue.get_ip_request_count(client_ip)
|
||||
return count
|
||||
|
||||
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
|
||||
queue = RedisPriorityQueue(backend_url)
|
||||
return queue.put(item, priority, selected_model, do_stream)
|
||||
|
||||
def activity(self):
|
||||
lines = []
|
||||
status_redis = RedisCustom('worker_status')
|
||||
for worker in status_redis.keys():
|
||||
lines.append((worker, status_redis.getp(worker)))
|
||||
return sorted(lines)
|
||||
|
||||
def len(self, model_name):
|
||||
count = 0
|
||||
backends_with_models = set()
|
||||
for k in self.get_backends():
|
||||
info = cluster_config.get_backend(k)
|
||||
if info.get('model') == model_name:
|
||||
backends_with_models.add(k)
|
||||
for backend_url in backends_with_models:
|
||||
count += len(RedisPriorityQueue(backend_url))
|
||||
return count
|
||||
|
||||
def __len__(self):
|
||||
count = 0
|
||||
p = set()
|
||||
for backend_url in self.get_backends():
|
||||
queue = RedisPriorityQueue(backend_url)
|
||||
p.add((backend_url, len(queue)))
|
||||
count += len(queue)
|
||||
return count
|
||||
|
||||
def flush(self):
|
||||
for k in self.redis.keys():
|
||||
q = json.loads(self.redis.get(k))
|
||||
q.flush()
|
||||
self.redis.set(k, json.dumps(q))
|
||||
|
||||
def flush_db(self):
|
||||
self.redis.flushdb()
|
||||
|
||||
|
||||
priority_queue = PriorityQueue()
|
||||
|
|
|
@ -5,23 +5,22 @@ import flask
|
|||
from flask import Response, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.cluster.cluster_config import cluster_config, get_a_cluster_backend
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_token_ratelimit
|
||||
from llm_server.database.log_to_db import log_to_db
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.helpers.http import require_api_key, validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
DEFAULT_PRIORITY = 9999
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str = None, incoming_json: Union[dict, str] = None):
|
||||
self.request = incoming_request
|
||||
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
|
||||
# self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
|
||||
|
||||
# Routes need to validate it, here we just load it
|
||||
if incoming_json:
|
||||
|
@ -34,11 +33,38 @@ class RequestHandler:
|
|||
self.start_time = time.time()
|
||||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.get_auth_token()
|
||||
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
|
||||
self.backend = get_backend()
|
||||
self.token_priority, self.token_simultaneous_ip = get_token_ratelimit(self.token)
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||
|
||||
# This is null by default since most handlers need to transform the prompt in a specific way.
|
||||
self.prompt = None
|
||||
|
||||
self.selected_model = selected_model
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
||||
# Debug stuff
|
||||
# if not self.cluster_backend_info.get('mode'):
|
||||
# print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
# if not self.cluster_backend_info.get('model'):
|
||||
# print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
# if not self.cluster_backend_info.get('model_config'):
|
||||
# print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
|
||||
|
||||
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
|
||||
self.offline = True
|
||||
else:
|
||||
self.offline = False
|
||||
self.selected_model = self.cluster_backend_info['model']
|
||||
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
|
||||
if self.token and not self.token.startswith('SYSTEM__'):
|
||||
# "recent_prompters" is only used for stats.
|
||||
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||
|
||||
def check_online(self) -> bool:
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
return self.cluster_backend_info['online']
|
||||
|
||||
def get_auth_token(self):
|
||||
if self.request_json_body.get('X-API-KEY'):
|
||||
|
@ -49,6 +75,8 @@ class RequestHandler:
|
|||
return parse_token(self.request.headers['Authorization'])
|
||||
|
||||
def get_client_ip(self):
|
||||
if self.request.headers.get('Llm-Connecting-Ip'):
|
||||
return self.request.headers['Llm-Connecting-Ip']
|
||||
if self.request.headers.get('X-Connecting-IP'):
|
||||
return self.request.headers.get('X-Connecting-IP')
|
||||
elif self.request.headers.get('Cf-Connecting-Ip'):
|
||||
|
@ -58,26 +86,7 @@ class RequestHandler:
|
|||
else:
|
||||
return self.request.remote_addr
|
||||
|
||||
def get_token_ratelimit(self):
|
||||
priority = DEFAULT_PRIORITY
|
||||
simultaneous_ip = opts.simultaneous_requests_per_ip
|
||||
if self.token:
|
||||
cursor = database.cursor()
|
||||
try:
|
||||
cursor.execute("SELECT priority, simultaneous_ip FROM token_auth WHERE token = %s", (self.token,))
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
priority, simultaneous_ip = result
|
||||
if simultaneous_ip is None:
|
||||
# No ratelimit for this token if null
|
||||
simultaneous_ip = 999999999
|
||||
finally:
|
||||
cursor.close()
|
||||
return priority, simultaneous_ip
|
||||
|
||||
def get_parameters(self):
|
||||
if self.request_json_body.get('max_tokens'):
|
||||
self.request_json_body['max_new_tokens'] = self.request_json_body.pop('max_tokens')
|
||||
parameters, parameters_invalid_msg = self.backend.get_parameters(self.request_json_body)
|
||||
return parameters, parameters_invalid_msg
|
||||
|
||||
|
@ -119,7 +128,7 @@ class RequestHandler:
|
|||
backend_response = self.handle_error(combined_error_message, 'Validation Error')
|
||||
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
|
||||
log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
|
||||
return False, backend_response
|
||||
return True, (None, 0)
|
||||
|
||||
|
@ -131,14 +140,18 @@ class RequestHandler:
|
|||
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
|
||||
if not request_valid:
|
||||
return (False, None, None, 0), invalid_response
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
|
||||
event = priority_queue.put(self.backend_url, (llm_request, self.client_ip, self.token, self.parameters), self.token_priority, self.selected_model)
|
||||
else:
|
||||
event = None
|
||||
|
||||
if not event:
|
||||
return (False, None, None, 0), self.handle_ratelimited()
|
||||
|
||||
# TODO: add wait timeout
|
||||
success, response, error_msg = event.wait()
|
||||
if error_msg == 'closed':
|
||||
return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - self.start_time
|
||||
|
||||
|
@ -160,7 +173,17 @@ class RequestHandler:
|
|||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
|
||||
log_to_db(ip=self.client_ip,
|
||||
token=self.token,
|
||||
prompt=prompt,
|
||||
response=backend_response[0].data.decode('utf-8'),
|
||||
gen_time=None,
|
||||
parameters=self.parameters,
|
||||
headers=dict(self.request.headers),
|
||||
backend_response_code=response_status_code,
|
||||
request_url=self.request.url,
|
||||
backend_url=self.backend_url,
|
||||
is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -180,7 +203,7 @@ class RequestHandler:
|
|||
if return_json_err:
|
||||
error_msg = 'The backend did not return valid JSON.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
|
||||
log_to_db(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -189,22 +212,29 @@ class RequestHandler:
|
|||
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
|
||||
|
||||
def is_client_ratelimited(self) -> bool:
|
||||
if self.token_priority == 0:
|
||||
return False
|
||||
|
||||
queued_ip_count = int(priority_queue.get_queued_ip_count(self.client_ip))
|
||||
x = redis.hget('processing_ips', self.client_ip)
|
||||
if x:
|
||||
processing_ip = int(x)
|
||||
else:
|
||||
processing_ip = 0
|
||||
if queued_ip_count + processing_ip < self.token_simultaneous_ip or self.token_priority == 0:
|
||||
return False
|
||||
else:
|
||||
print(f'Rejecting request from {self.client_ip} - {queued_ip_count + processing_ip} queued + processing.')
|
||||
|
||||
if queued_ip_count + processing_ip >= self.token_simultaneous_ip:
|
||||
print(f'Rejecting request from {self.client_ip} - {processing_ip} processing, {queued_ip_count} queued')
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def handle_request(self) -> Tuple[flask.Response, int]:
|
||||
# Must include this in your child.
|
||||
# if self.used:
|
||||
# raise Exception('Can only use a RequestHandler object once.')
|
||||
# assert not self.used
|
||||
# if self.offline:
|
||||
# msg = f'{self.selected_model} is not a valid model choice.'
|
||||
# print(msg)
|
||||
# return format_sillytavern_err(msg)
|
||||
raise NotImplementedError
|
||||
|
||||
def handle_ratelimited(self, do_log: bool = True) -> Tuple[flask.Response, int]:
|
||||
|
@ -214,11 +244,11 @@ class RequestHandler:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_backend():
|
||||
if opts.mode == 'oobabooga':
|
||||
return OobaboogaBackend()
|
||||
elif opts.mode == 'vllm':
|
||||
return VLLMBackend()
|
||||
def get_backend_handler(mode, backend_url: str):
|
||||
if mode == 'oobabooga':
|
||||
return OobaboogaBackend(backend_url)
|
||||
elif mode == 'vllm':
|
||||
return VLLMBackend(backend_url)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
|
|
@ -1,3 +1,3 @@
|
|||
def handle_server_error(e):
|
||||
print(e)
|
||||
return {'error': True}, 500
|
||||
print('Internal Error:', e)
|
||||
return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500
|
||||
|
|
|
@ -1,33 +1,11 @@
|
|||
from datetime import datetime
|
||||
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
# proompters_5_min = 0
|
||||
# concurrent_semaphore = Semaphore(concurrent_gens)
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.helpers import round_up_base
|
||||
|
||||
server_start_time = datetime.now()
|
||||
|
||||
|
||||
# TODO: do I need this?
|
||||
# def elapsed_times_cleanup():
|
||||
# global wait_in_queue_elapsed
|
||||
# while True:
|
||||
# current_time = time.time()
|
||||
# with wait_in_queue_elapsed_lock:
|
||||
# global wait_in_queue_elapsed
|
||||
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
|
||||
# time.sleep(1)
|
||||
|
||||
|
||||
def calculate_avg_gen_time():
|
||||
# Get the average generation time from Redis
|
||||
average_generation_time = redis.get('average_generation_time')
|
||||
if average_generation_time is None:
|
||||
return 0
|
||||
else:
|
||||
return float(average_generation_time)
|
||||
|
||||
|
||||
def get_total_proompts():
|
||||
count = redis.get('proompts')
|
||||
if count is None:
|
||||
|
@ -37,10 +15,27 @@ def get_total_proompts():
|
|||
return count
|
||||
|
||||
|
||||
def get_active_gen_workers():
|
||||
active_gen_workers = redis.get('active_gen_workers')
|
||||
if active_gen_workers is None:
|
||||
count = 0
|
||||
def get_active_gen_workers_model(selected_model: str = None):
|
||||
return redis.get(f'active_gen_workers:{selected_model}', dtype=int, default=0)
|
||||
|
||||
|
||||
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
|
||||
if active_gen_workers < concurrent_gens:
|
||||
return 0
|
||||
elif active_gen_workers >= concurrent_gens:
|
||||
# Calculate how long it will take to complete the currently running gens and the queued requests.
|
||||
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
|
||||
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
|
||||
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
|
||||
# Add gen_time_calc to the time to account for the currently running generations.
|
||||
# This assumes that all active workers will finish at the same time, which is unlikely.
|
||||
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
|
||||
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
|
||||
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
|
||||
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
|
||||
elif proompters_in_queue == 0 and active_gen_workers == 0:
|
||||
# No queue, no workers
|
||||
return 0
|
||||
else:
|
||||
count = int(active_gen_workers)
|
||||
return count
|
||||
# No queue
|
||||
return gen_time_calc
|
||||
|
|
|
@ -3,18 +3,18 @@ import traceback
|
|||
from flask import jsonify, request
|
||||
|
||||
from . import bp
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
|
||||
|
||||
@bp.route('/generate', methods=['POST'])
|
||||
def generate():
|
||||
@bp.route('/v1/generate', methods=['POST'])
|
||||
@bp.route('/<model_name>/v1/generate', methods=['POST'])
|
||||
def generate(model_name=None):
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
handler = OobaRequestHandler(request)
|
||||
handler = OobaRequestHandler(request, selected_model=model_name)
|
||||
try:
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
|
|
|
@ -2,83 +2,30 @@ import time
|
|||
from datetime import datetime
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_distinct_ips_24h, sum_column
|
||||
from llm_server.helpers import deep_sort, round_up_base
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.netdata import get_power_states
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
|
||||
from llm_server.helpers import deep_sort
|
||||
from llm_server.routes.stats import get_total_proompts, server_start_time
|
||||
|
||||
|
||||
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
|
||||
if active_gen_workers < concurrent_gens:
|
||||
return 0
|
||||
elif active_gen_workers >= concurrent_gens:
|
||||
# Calculate how long it will take to complete the currently running gens and the queued requests.
|
||||
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
|
||||
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
|
||||
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
|
||||
# Add gen_time_calc to the time to account for the currently running generations.
|
||||
# This assumes that all active workers will finish at the same time, which is unlikely.
|
||||
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
|
||||
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
|
||||
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
|
||||
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
|
||||
elif proompters_in_queue == 0 and active_gen_workers == 0:
|
||||
# No queue, no workers
|
||||
return 0
|
||||
else:
|
||||
# No queue
|
||||
return gen_time_calc
|
||||
|
||||
|
||||
# TODO: have routes/__init__.py point to the latest API version generate_stats()
|
||||
|
||||
def generate_stats(regen: bool = False):
|
||||
if not regen:
|
||||
c = redis.get('proxy_stats', dict)
|
||||
c = redis.getp('proxy_stats')
|
||||
if c:
|
||||
return c
|
||||
|
||||
model_name, error = get_running_model() # will return False when the fetch fails
|
||||
if isinstance(model_name, bool):
|
||||
online = False
|
||||
else:
|
||||
online = True
|
||||
redis.set('running_model', model_name)
|
||||
model_choices, default_model = get_model_choices(regen=True)
|
||||
|
||||
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
|
||||
# if len(t) == 0:
|
||||
# estimated_wait = 0
|
||||
# else:
|
||||
# waits = [elapsed for end, elapsed in t]
|
||||
# estimated_wait = int(sum(waits) / len(waits))
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
proompters_in_queue = len(priority_queue)
|
||||
|
||||
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
|
||||
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
|
||||
|
||||
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
|
||||
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
|
||||
|
||||
if opts.netdata_root:
|
||||
netdata_stats = {}
|
||||
power_states = get_power_states()
|
||||
for gpu, power_state in power_states.items():
|
||||
netdata_stats[gpu] = {
|
||||
'power_state': power_state,
|
||||
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
|
||||
}
|
||||
else:
|
||||
netdata_stats = {}
|
||||
|
||||
base_client_api = redis.get('base_client_api', str)
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
|
||||
|
||||
output = {
|
||||
'models': {
|
||||
'choices': model_choices,
|
||||
'default': default_model,
|
||||
},
|
||||
'stats': {
|
||||
'proompters': {
|
||||
'5_min': proompters_5_min,
|
||||
|
@ -86,39 +33,49 @@ def generate_stats(regen: bool = False):
|
|||
},
|
||||
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
|
||||
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
|
||||
'average_generation_elapsed_sec': int(average_generation_time),
|
||||
# 'estimated_avg_tps': estimated_avg_tps,
|
||||
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
|
||||
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
|
||||
},
|
||||
'online': online,
|
||||
'endpoints': {
|
||||
'blocking': f'https://{base_client_api}',
|
||||
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||
},
|
||||
'queue': {
|
||||
'processing': active_gen_workers,
|
||||
'queued': proompters_in_queue,
|
||||
'estimated_wait_sec': int(estimated_wait_sec),
|
||||
},
|
||||
'timestamp': int(time.time()),
|
||||
'config': {
|
||||
'gatekeeper': 'none' if opts.auth_required is False else 'token',
|
||||
'context_size': opts.context_size,
|
||||
'concurrent': opts.concurrent_gens,
|
||||
'model': opts.manual_model_name if opts.manual_model_name else model_name,
|
||||
'mode': opts.mode,
|
||||
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
|
||||
'api_mode': opts.frontend_api_mode
|
||||
},
|
||||
'keys': {
|
||||
'openaiKeys': '∞',
|
||||
'anthropicKeys': '∞',
|
||||
},
|
||||
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
|
||||
'nvidia': netdata_stats
|
||||
'backends': {},
|
||||
'online': len(model_choices) > 0
|
||||
}
|
||||
|
||||
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
|
||||
|
||||
if opts.show_backends:
|
||||
for backend_url, v in cluster_config.all().items():
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if not backend_info['online']:
|
||||
continue
|
||||
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
|
||||
output['backends'][backend_info['hash']] = {
|
||||
'uptime': backend_uptime,
|
||||
'max_tokens': backend_info['model_config'].get('max_position_embeddings', -1),
|
||||
'model': backend_info['model'],
|
||||
'mode': backend_info['mode'],
|
||||
'nvidia': backend_info['nvidia'],
|
||||
'priority': backend_info['priority'],
|
||||
}
|
||||
|
||||
result = deep_sort(output)
|
||||
|
||||
# It may take a bit to get the base client API, so don't cache until then.
|
||||
if base_client_api:
|
||||
redis.set_dict('proxy_stats', result) # Cache with no expiry
|
||||
redis.setp('proxy_stats', result)
|
||||
|
||||
return result
|
||||
|
|
|
@ -1,186 +1,200 @@
|
|||
import json
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
import ujson
|
||||
from flask import request
|
||||
from redis import Redis
|
||||
|
||||
from ..cache import redis
|
||||
from . import bp
|
||||
from ..helpers.http import require_api_key, validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ..queue import priority_queue
|
||||
from ... import opts
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.vllm import tokenize
|
||||
from ...stream import sock
|
||||
from ...custom_redis import redis
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...sock import sock
|
||||
|
||||
|
||||
# TODO: have workers process streaming requests
|
||||
# TODO: make sure to log the token as well (seems to be missing in the DB right now)
|
||||
# Stacking the @sock.route() creates a TypeError error on the /v1/stream endpoint.
|
||||
# We solve this by splitting the routes
|
||||
|
||||
@sock.route('/api/v1/stream')
|
||||
def stream(ws):
|
||||
def send_err_and_quit(quitting_err_msg):
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': 0,
|
||||
'text': quitting_err_msg
|
||||
}))
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': 1
|
||||
}))
|
||||
ws.close()
|
||||
log_in_bg(quitting_err_msg, is_error=True)
|
||||
@bp.route('/v1/stream')
|
||||
@bp.route('/<model_name>/v1/stream')
|
||||
def stream(model_name=None):
|
||||
return 'This is a websocket endpoint.', 400
|
||||
|
||||
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
|
||||
|
||||
def background_task_exception():
|
||||
generated_tokens = tokenize(generated_text_bg)
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
|
||||
@sock.route('/v1/stream', bp=bp)
|
||||
def stream_without_model(ws):
|
||||
do_stream(ws, model_name=None)
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task_exception)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming is disabled', 401
|
||||
@sock.route('/<model_name>/v1/stream', bp=bp)
|
||||
def stream_with_model(ws, model_name=None):
|
||||
do_stream(ws, model_name)
|
||||
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
message_num = 0
|
||||
while ws.connected:
|
||||
message = ws.receive()
|
||||
request_valid_json, request_json_body = validate_json(message)
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return 'Invalid JSON', 400
|
||||
else:
|
||||
if opts.mode != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
return auth_failure
|
||||
def do_stream(ws, model_name):
|
||||
event_id = None
|
||||
try:
|
||||
def send_err_and_quit(quitting_err_msg):
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': 0,
|
||||
'text': quitting_err_msg
|
||||
}))
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': 1
|
||||
}))
|
||||
ws.close()
|
||||
log_to_db(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=quitting_err_msg,
|
||||
gen_time=None,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
|
||||
handler = OobaRequestHandler(request, request_json_body)
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
err_msg = None
|
||||
if handler.is_client_ratelimited():
|
||||
r, _ = handler.handle_ratelimited(do_log=False)
|
||||
err_msg = r.json['results'][0]['text']
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
message_num = 0
|
||||
|
||||
while ws.connected:
|
||||
message = ws.receive()
|
||||
request_valid_json, request_json_body = validate_json(message)
|
||||
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return 'Invalid JSON', 400
|
||||
else:
|
||||
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
|
||||
if not request_valid:
|
||||
err_msg = invalid_response[0].json['results'][0]['text']
|
||||
if err_msg:
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
# We have to do auth ourselves since the details are sent in the message.
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
return auth_failure
|
||||
|
||||
llm_request = {
|
||||
**handler.parameters,
|
||||
'prompt': input_prompt,
|
||||
'stream': True,
|
||||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
|
||||
if not event:
|
||||
r, _ = handler.handle_ratelimited()
|
||||
err_msg = r.json['results'][0]['text']
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
try:
|
||||
response = generator(llm_request)
|
||||
if not response:
|
||||
error_msg = 'Failed to reach backend while streaming.'
|
||||
print('Streaming failed:', error_msg)
|
||||
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
|
||||
handler = OobaRequestHandler(incoming_request=request, selected_model=model_name, incoming_json=request_json_body)
|
||||
if handler.offline:
|
||||
msg = f'{handler.selected_model} is not a valid model choice.'
|
||||
print(msg)
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'message_num': 0,
|
||||
'text': msg
|
||||
}))
|
||||
return
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
||||
err_msg = None
|
||||
if handler.is_client_ratelimited():
|
||||
r, _ = handler.handle_ratelimited(do_log=False)
|
||||
err_msg = r.json['results'][0]['text']
|
||||
else:
|
||||
# Be extra careful when getting attributes from the response object
|
||||
try:
|
||||
response_status_code = response.status_code
|
||||
except:
|
||||
response_status_code = 0
|
||||
request_valid, invalid_response = handler.validate_request(prompt=input_prompt)
|
||||
if not request_valid:
|
||||
err_msg = invalid_response[0].json['results'][0]['text']
|
||||
if err_msg:
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
|
||||
partial_response = b''
|
||||
handler.parameters, _ = handler.get_parameters()
|
||||
handler.prompt = input_prompt
|
||||
handler.request_json_body = {
|
||||
'prompt': handler.prompt,
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': new
|
||||
}))
|
||||
except:
|
||||
# The has client closed the stream.
|
||||
if request:
|
||||
request.close()
|
||||
ws.close()
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
||||
return
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
r = handler.handle_ratelimited()
|
||||
send_err_and_quit(r[0].data)
|
||||
return
|
||||
event_id = event.event_id
|
||||
|
||||
_, stream_name, error_msg = event.wait()
|
||||
if error_msg:
|
||||
print('Stream failed to start streaming:', error_msg)
|
||||
ws.close(reason=1014, message='Request Timeout')
|
||||
return
|
||||
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
|
||||
try:
|
||||
last_id = '0-0' # The ID of the last entry we read.
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
|
||||
if not stream_data:
|
||||
print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
|
||||
return
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
data = ujson.loads(item[b'data'])
|
||||
if data['error']:
|
||||
print(data['error'])
|
||||
send_err_and_quit('Encountered exception while streaming.')
|
||||
return
|
||||
elif data['new']:
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': data['new']
|
||||
}))
|
||||
message_num += 1
|
||||
partial_response = b'' # Reset the partial response
|
||||
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_in_bg(generated_text, elapsed_time_bg=elapsed_time, is_error=not response, status_code=response_status_code)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': generated_text
|
||||
}))
|
||||
if request:
|
||||
request.close()
|
||||
ws.close()
|
||||
log_in_bg(generated_text, is_error=True, status_code=response_status_code)
|
||||
return
|
||||
finally:
|
||||
# The worker incremented it, we'll decrement it.
|
||||
decrement_ip_count(handler.client_ip, 'processing_ips')
|
||||
decr_active_workers()
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
except:
|
||||
# The client closed the stream.
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
||||
ws.close() # this is important if we encountered and error and exited early.
|
||||
generated_text = generated_text + data['new']
|
||||
elif data['completed']:
|
||||
return
|
||||
except:
|
||||
send_err_and_quit('Encountered exception while streaming.')
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
except:
|
||||
# The client closed the stream.
|
||||
pass
|
||||
if stream_name:
|
||||
stream_redis.delete(stream_name)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
response=generated_text,
|
||||
gen_time=elapsed_time,
|
||||
parameters=handler.parameters,
|
||||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.backend_url
|
||||
)
|
||||
finally:
|
||||
if event_id:
|
||||
redis.publish(f'notifications:{event_id}', 'canceled')
|
||||
try:
|
||||
# Must close the connection or greenlets will complain.
|
||||
ws.close()
|
||||
except:
|
||||
pass
|
||||
|
|
|
@ -2,22 +2,16 @@ import time
|
|||
|
||||
from flask import jsonify, request
|
||||
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from . import bp
|
||||
from ..auth import requires_auth
|
||||
from ..cache import flask_cache
|
||||
from ... import opts
|
||||
from ...llm.info import get_running_model
|
||||
from ...cluster.backend import get_backends_from_model, is_valid_model
|
||||
from ...cluster.cluster_config import cluster_config, get_a_cluster_backend
|
||||
|
||||
|
||||
# @bp.route('/info', methods=['GET'])
|
||||
# # @cache.cached(timeout=3600, query_string=True)
|
||||
# def get_info():
|
||||
# # requests.get()
|
||||
# return 'yes'
|
||||
|
||||
|
||||
@bp.route('/model', methods=['GET'])
|
||||
def get_model():
|
||||
@bp.route('/v1/model', methods=['GET'])
|
||||
@bp.route('/<model_name>/v1/model', methods=['GET'])
|
||||
def get_model(model_name=None):
|
||||
# We will manage caching ourself since we don't want to cache
|
||||
# when the backend is down. Also, Cloudflare won't cache 500 errors.
|
||||
cache_key = 'model_cache::' + request.url
|
||||
|
@ -26,24 +20,21 @@ def get_model():
|
|||
if cached_response:
|
||||
return cached_response
|
||||
|
||||
model_name, error = get_running_model()
|
||||
if not model_name:
|
||||
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
|
||||
|
||||
if not is_valid_model(model_name):
|
||||
response = jsonify({
|
||||
'code': 502,
|
||||
'msg': 'failed to reach backend',
|
||||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
'code': 400,
|
||||
'msg': 'Model does not exist.',
|
||||
}), 400
|
||||
else:
|
||||
num_backends = len(get_backends_from_model(model_name))
|
||||
response = jsonify({
|
||||
'result': opts.manual_model_name if opts.manual_model_name else model_name,
|
||||
'model_backend_count': num_backends,
|
||||
'timestamp': int(time.time())
|
||||
}), 200
|
||||
flask_cache.set(cache_key, response, timeout=60)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@bp.route('/backend', methods=['GET'])
|
||||
@requires_auth
|
||||
def get_backend():
|
||||
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
from flask import jsonify
|
||||
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from . import bp
|
||||
from .generate_stats import generate_stats
|
||||
from ..cache import flask_cache
|
||||
from ..auth import requires_auth
|
||||
from ...cluster.cluster_config import cluster_config, get_backends
|
||||
from ...helpers import jsonify_pretty
|
||||
|
||||
|
||||
|
@ -10,3 +12,14 @@ from ...helpers import jsonify_pretty
|
|||
@flask_cache.cached(timeout=5, query_string=True)
|
||||
def get_stats():
|
||||
return jsonify_pretty(generate_stats())
|
||||
|
||||
|
||||
@bp.route('/backends', methods=['GET'])
|
||||
@requires_auth
|
||||
def get_backend():
|
||||
online, offline = get_backends()
|
||||
result = {}
|
||||
for i in online + offline:
|
||||
info = cluster_config.get_backend(i)
|
||||
result[info['hash']] = info
|
||||
return jsonify(result), 200
|
||||
|
|
|
@ -3,6 +3,6 @@ from flask_sock import Sock
|
|||
sock = Sock()
|
||||
|
||||
|
||||
def init_socketio(app):
|
||||
def init_wssocket(app):
|
||||
global sock
|
||||
sock.init_app(app)
|
|
@ -1,35 +0,0 @@
|
|||
from threading import Thread
|
||||
|
||||
from .blocking import start_workers
|
||||
from .main import main_background_thread
|
||||
from .moderator import start_moderation_workers
|
||||
from .printer import console_printer
|
||||
from .recent import recent_prompters_thread
|
||||
from .threads import cache_stats
|
||||
from .. import opts
|
||||
|
||||
|
||||
def start_background():
|
||||
start_workers(opts.concurrent_gens)
|
||||
|
||||
t = Thread(target=main_background_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the main background thread.')
|
||||
|
||||
start_moderation_workers(opts.openai_moderation_workers)
|
||||
|
||||
t = Thread(target=cache_stats)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the stats cacher.')
|
||||
|
||||
t = Thread(target=recent_prompters_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the recent proompters thread.')
|
||||
|
||||
t = Thread(target=console_printer)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the console printer.')
|
|
@ -1,51 +0,0 @@
|
|||
import threading
|
||||
import time
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
|
||||
|
||||
|
||||
def worker():
|
||||
while True:
|
||||
need_to_wait()
|
||||
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
|
||||
need_to_wait()
|
||||
|
||||
increment_ip_count(client_ip, 'processing_ips')
|
||||
incr_active_workers()
|
||||
|
||||
if not request_json_body:
|
||||
# This was a dummy request from the websocket handler.
|
||||
# We're going to let the websocket handler decrement processing_ips and active_gen_workers.
|
||||
continue
|
||||
|
||||
try:
|
||||
success, response, error_msg = generator(request_json_body)
|
||||
event = DataEvent(event_id)
|
||||
event.set((success, response, error_msg))
|
||||
finally:
|
||||
decrement_ip_count(client_ip, 'processing_ips')
|
||||
decr_active_workers()
|
||||
|
||||
|
||||
def start_workers(num_workers: int):
|
||||
i = 0
|
||||
for _ in range(num_workers):
|
||||
t = threading.Thread(target=worker)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
i += 1
|
||||
print(f'Started {i} inference workers.')
|
||||
|
||||
|
||||
def need_to_wait():
|
||||
# We need to check the number of active workers since the streaming endpoint may be doing something.
|
||||
active_workers = redis.get('active_gen_workers', int, 0)
|
||||
s = time.time()
|
||||
while active_workers >= opts.concurrent_gens:
|
||||
time.sleep(0.01)
|
||||
e = time.time()
|
||||
if e - s > 0.5:
|
||||
print(f'Worker was delayed {e - s} seconds.')
|
|
@ -0,0 +1,32 @@
|
|||
import time
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.workers.inferencer import STREAM_NAME_PREFIX
|
||||
|
||||
|
||||
# NOT NEEDED
|
||||
|
||||
def cleaner():
|
||||
r = Redis(db=8)
|
||||
stream_info = {}
|
||||
|
||||
while True:
|
||||
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
|
||||
processed_streams = []
|
||||
for stream in all_streams:
|
||||
stream = stream.decode()
|
||||
current_size = r.xlen(stream)
|
||||
|
||||
# If the stream is new or its size has changed, update the size and time in the dictionary
|
||||
if stream not in stream_info or current_size != stream_info[stream]['size']:
|
||||
stream_info[stream] = {'size': current_size, 'time': time.time()}
|
||||
processed_streams.append(stream)
|
||||
else:
|
||||
# If the size hasn't changed for 5 minutes, delete the stream
|
||||
if time.time() - stream_info[stream]['time'] >= 300:
|
||||
r.delete(stream)
|
||||
print(f"Stream '{stream}' deleted due to inactivity.")
|
||||
del stream_info[stream]
|
||||
|
||||
time.sleep(60)
|
|
@ -0,0 +1,160 @@
|
|||
import json
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from uuid import uuid4
|
||||
|
||||
import ujson
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
|
||||
|
||||
stream_redis = Redis(db=8)
|
||||
|
||||
STREAM_NAME_PREFIX = 'stream'
|
||||
|
||||
|
||||
def check_cancellation(event, event_id):
|
||||
"""
|
||||
This thread checks the pub/sub channel in the background so the main process
|
||||
isn't bogged down with Redis calls. Otherwise, the main process slows down to 1 token/sec.
|
||||
:param event:
|
||||
:param event_id:
|
||||
:return:
|
||||
"""
|
||||
pubsub = redis.pubsub()
|
||||
pubsub.subscribe(f'notifications:{event_id}')
|
||||
while not event.is_set():
|
||||
message = pubsub.get_message()
|
||||
if message and message['data'] == b'canceled':
|
||||
event.set()
|
||||
time.sleep(0.5) # check every half second
|
||||
|
||||
|
||||
def get_stream_name(name: str):
|
||||
return f'{STREAM_NAME_PREFIX}:{name}'
|
||||
|
||||
|
||||
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
|
||||
logger = create_logger('inferencer')
|
||||
prompt = msg_to_backend['prompt']
|
||||
stream_name = get_stream_name(stream_name)
|
||||
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
|
||||
event = threading.Event()
|
||||
threading.Thread(target=check_cancellation, args=(event, event_id)).start()
|
||||
try:
|
||||
response = generator(msg_to_backend, backend_url)
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
if event.is_set():
|
||||
logger.debug('Client canceled generation')
|
||||
response.close()
|
||||
return
|
||||
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
|
||||
except AttributeError as e:
|
||||
if str(e) == "'bool' object has no attribute 'iter_content'":
|
||||
# We don't care about these errors.
|
||||
logger.debug('failed to stream from backend - no response')
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
|
||||
raise # We won't handle the exception here.
|
||||
finally:
|
||||
# Publish final message to Redis stream
|
||||
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
|
||||
event.set() # stop the cancellation checking thread
|
||||
|
||||
|
||||
#
|
||||
def worker(backend_url):
|
||||
logger = create_logger('inferencer')
|
||||
status_redis = RedisCustom('worker_status')
|
||||
worker_id = str(uuid4())
|
||||
status_redis.setp(str(worker_id), None)
|
||||
redis_queue = RedisPriorityQueue(backend_url)
|
||||
while True:
|
||||
status_redis.setp(str(worker_id), 'waiting...')
|
||||
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
|
||||
event = DataEvent(event_id)
|
||||
|
||||
try:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
except:
|
||||
# This is not a critical error because it usually means that the backend is
|
||||
# offline and this backend is in a state of transition from online to offline.
|
||||
logger.debug(f'got an exception while getting info for backend {backend_url} - ', traceback.format_exc())
|
||||
event.set((False, None, 'exception'))
|
||||
continue
|
||||
|
||||
if not backend_info['online']:
|
||||
event.set((False, None, 'canceled'))
|
||||
continue
|
||||
|
||||
if not selected_model:
|
||||
selected_model = backend_info['model']
|
||||
|
||||
logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}. Streaming: {do_stream}")
|
||||
|
||||
try:
|
||||
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
|
||||
increment_ip_count(client_ip, 'processing_ips')
|
||||
incr_active_workers(selected_model, backend_url)
|
||||
|
||||
if do_stream:
|
||||
status_redis.setp(str(worker_id), ('streaming', client_ip))
|
||||
|
||||
# Return the name of the stream that the slave should connect to.
|
||||
event.set((True, get_stream_name(worker_id), None))
|
||||
|
||||
msg_to_backend = {
|
||||
**parameters,
|
||||
'prompt': request_json_body['prompt'],
|
||||
'stream': True,
|
||||
}
|
||||
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
|
||||
else:
|
||||
# Normal inference (not streaming).
|
||||
status_redis.setp(str(worker_id), ('generating', client_ip))
|
||||
success, response, error_msg = generator(request_json_body, backend_url)
|
||||
event.set((success, response, error_msg))
|
||||
except:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set((False, None, 'exception'))
|
||||
finally:
|
||||
decrement_ip_count(client_ip, 'processing_ips')
|
||||
decr_active_workers(selected_model, backend_url)
|
||||
status_redis.setp(str(worker_id), None)
|
||||
|
||||
|
||||
def start_workers(cluster: dict):
|
||||
logger = create_logger('inferencer')
|
||||
i = 0
|
||||
for item in cluster:
|
||||
for _ in range(item['concurrent_gens']):
|
||||
t = threading.Thread(target=worker, args=(item['backend_url'],))
|
||||
t.daemon = True
|
||||
t.start()
|
||||
i += 1
|
||||
logger.info(f'Started {i} inference workers.')
|
|
@ -0,0 +1,31 @@
|
|||
import pickle
|
||||
import traceback
|
||||
|
||||
import redis
|
||||
|
||||
from llm_server.database.database import do_db_log
|
||||
|
||||
|
||||
def db_logger():
|
||||
"""
|
||||
We don't want the logging operation to be blocking, so we will use a background worker
|
||||
to do the logging.
|
||||
:return:
|
||||
"""
|
||||
|
||||
r = redis.Redis(host='localhost', port=6379, db=3)
|
||||
p = r.pubsub()
|
||||
p.subscribe('database-logger')
|
||||
|
||||
for message in p.listen():
|
||||
try:
|
||||
if message['type'] == 'message':
|
||||
data = pickle.loads(message['data'])
|
||||
function_name = data['function']
|
||||
args = data['args']
|
||||
kwargs = data['kwargs']
|
||||
|
||||
if function_name == 'log_prompt':
|
||||
do_db_log(*args, **kwargs)
|
||||
except:
|
||||
traceback.print_exc()
|
|
@ -1,56 +0,0 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database.database import weighted_average_column_for_model
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.routes.cache import redis
|
||||
|
||||
|
||||
def main_background_thread():
|
||||
redis.set('average_generation_elapsed_sec', 0)
|
||||
redis.set('estimated_avg_tps', 0)
|
||||
redis.set('average_output_tokens', 0)
|
||||
redis.set('backend_online', 0)
|
||||
redis.set_dict('backend_info', {})
|
||||
|
||||
while True:
|
||||
# TODO: unify this
|
||||
if opts.mode == 'oobabooga':
|
||||
running_model, err = get_running_model()
|
||||
if err:
|
||||
print(err)
|
||||
redis.set('backend_online', 0)
|
||||
else:
|
||||
redis.set('running_model', running_model)
|
||||
redis.set('backend_online', 1)
|
||||
elif opts.mode == 'vllm':
|
||||
running_model, err = get_running_model()
|
||||
if err:
|
||||
print(err)
|
||||
redis.set('backend_online', 0)
|
||||
else:
|
||||
redis.set('running_model', running_model)
|
||||
redis.set('backend_online', 1)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
if average_generation_elapsed_sec: # returns None on exception
|
||||
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
|
||||
|
||||
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
|
||||
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
|
||||
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
if average_generation_elapsed_sec:
|
||||
redis.set('average_output_tokens', average_output_tokens)
|
||||
|
||||
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
|
||||
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
|
||||
|
||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||
redis.set('estimated_avg_tps', estimated_avg_tps)
|
||||
time.sleep(60)
|
|
@ -0,0 +1,57 @@
|
|||
import time
|
||||
|
||||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config, get_backends
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import weighted_average_column_for_model
|
||||
from llm_server.llm.info import get_info
|
||||
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
|
||||
|
||||
|
||||
def main_background_thread():
|
||||
while True:
|
||||
online, offline = get_backends()
|
||||
for backend_url in online:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
backend_mode = backend_info['mode']
|
||||
backend_info = get_info(backend_url, backend_mode)
|
||||
running_model = backend_info.get('model')
|
||||
if not running_model:
|
||||
continue
|
||||
|
||||
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
|
||||
if average_generation_elapsed_sec: # returns None on exception
|
||||
cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
|
||||
if average_output_tokens:
|
||||
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
|
||||
if average_generation_elapsed_sec and average_output_tokens:
|
||||
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
|
||||
|
||||
if opts.background_homepage_cacher:
|
||||
try:
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
r = requests.get('https://' + base_client_api, timeout=5)
|
||||
except Exception as e:
|
||||
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
|
||||
|
||||
backends = priority_queue.get_backends()
|
||||
for backend_url in backends:
|
||||
queue = RedisPriorityQueue(backend_url)
|
||||
queue.cleanup()
|
||||
|
||||
time.sleep(30)
|
||||
|
||||
|
||||
def calc_stats_for_backend(backend_url, running_model, backend_mode):
|
||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
|
||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||
include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
|
||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||
include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps
|
|
@ -1,10 +1,13 @@
|
|||
import json
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import redis as redis_redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.llm.openai.moderation import check_moderation_endpoint
|
||||
from llm_server.logging import create_logger
|
||||
|
||||
redis_moderation = redis_redis.Redis()
|
||||
|
||||
|
@ -16,36 +19,43 @@ def start_moderation_workers(num_workers):
|
|||
t.daemon = True
|
||||
t.start()
|
||||
i += 1
|
||||
print(f'Started {i} moderation workers.')
|
||||
|
||||
|
||||
def moderation_worker():
|
||||
while True:
|
||||
result = redis_moderation.blpop('queue:msgs_to_check')
|
||||
try:
|
||||
msg, tag = json.loads(result[1])
|
||||
_, categories = check_moderation_endpoint(msg)
|
||||
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
|
||||
except:
|
||||
print(result)
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
|
||||
def add_moderation_task(msg, tag):
|
||||
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
|
||||
|
||||
# TODO: don't use UUID tags to identify items. Use native redis
|
||||
|
||||
def get_results(tag, num_tasks):
|
||||
tag = str(tag) # Required for comparison with Redis results.
|
||||
tag = str(tag) # Cast a UUID4 to a string.
|
||||
flagged_categories = set()
|
||||
num_results = 0
|
||||
start_time = time.time()
|
||||
while num_results < num_tasks:
|
||||
result = redis_moderation.blpop('queue:flagged_categories')
|
||||
result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
|
||||
if result is None:
|
||||
break # Timeout occurred, break the loop.
|
||||
result_tag, categories = json.loads(result[1])
|
||||
if result_tag == tag:
|
||||
if categories:
|
||||
for item in categories:
|
||||
flagged_categories.add(item)
|
||||
num_results += 1
|
||||
if time.time() - start_time > opts.openai_moderation_timeout:
|
||||
logger.warning('Timed out waiting for result from moderator')
|
||||
break
|
||||
return list(flagged_categories)
|
||||
|
||||
|
||||
def moderation_worker():
|
||||
logger = create_logger('moderator')
|
||||
while True:
|
||||
result = redis_moderation.blpop(['queue:msgs_to_check'])
|
||||
try:
|
||||
msg, tag = json.loads(result[1])
|
||||
_, categories = check_moderation_endpoint(msg)
|
||||
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
|
||||
except:
|
||||
logger.error(traceback.format_exc())
|
||||
continue
|
||||
|
||||
|
||||
def add_moderation_task(msg, tag):
|
||||
redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))
|
||||
|
|
|
@ -1,25 +1,34 @@
|
|||
import logging
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.cluster.backend import get_running_models
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
logger = logging.getLogger('console_printer')
|
||||
if not logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
handler.setLevel(logging.INFO)
|
||||
logger.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
logger.addHandler(handler)
|
||||
|
||||
|
||||
def console_printer():
|
||||
logger = create_logger('console_printer')
|
||||
time.sleep(3)
|
||||
while True:
|
||||
processing = redis.hkeys('processing_ips')
|
||||
processing_count = 0
|
||||
for ip in processing:
|
||||
processing_count += int(redis.hget('processing_ips', ip))
|
||||
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
|
||||
try:
|
||||
processing = redis.keys('active_gen_workers:http*') # backends always start with http
|
||||
processing_count = 0
|
||||
if len(processing):
|
||||
for k in processing:
|
||||
processing_count += redis.get(k, default=0, dtype=int)
|
||||
backends = [k for k, v in cluster_config.all().items() if v['online']]
|
||||
activity = priority_queue.activity()
|
||||
|
||||
# Calculate the queue size the same way it's done on the stats.
|
||||
queue_size = 0
|
||||
running_models = get_running_models()
|
||||
for model in running_models:
|
||||
queue_size += priority_queue.len(model)
|
||||
|
||||
# Active Workers and Processing should read the same. If not, that's an issue.
|
||||
logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
|
||||
except:
|
||||
logger.error(traceback.format_exc())
|
||||
time.sleep(10)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import time
|
||||
|
||||
from llm_server.routes.cache import redis
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def recent_prompters_thread():
|
|
@ -0,0 +1,58 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.worker import cluster_worker
|
||||
from llm_server.logging import create_logger
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.inferencer import start_workers
|
||||
from llm_server.workers.logger import db_logger
|
||||
from llm_server.workers.mainer import main_background_thread
|
||||
from llm_server.workers.moderator import start_moderation_workers
|
||||
from llm_server.workers.printer import console_printer
|
||||
from llm_server.workers.recenter import recent_prompters_thread
|
||||
|
||||
|
||||
def cache_stats():
|
||||
while True:
|
||||
generate_stats(regen=True)
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def start_background():
|
||||
logger = create_logger('threader')
|
||||
start_workers(opts.cluster)
|
||||
|
||||
t = Thread(target=main_background_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
logger.info('Started the main background thread.')
|
||||
|
||||
num_moderators = opts.cluster_workers * 3
|
||||
start_moderation_workers(num_moderators)
|
||||
logger.info(f'Started {num_moderators} moderation workers.')
|
||||
|
||||
t = Thread(target=cache_stats)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
logger.info('Started the stats cacher.')
|
||||
|
||||
t = Thread(target=recent_prompters_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
logger.info('Started the recent proompters thread.')
|
||||
|
||||
t = Thread(target=console_printer)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
logger.info('Started the console logger.infoer.')
|
||||
|
||||
t = Thread(target=cluster_worker)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
logger.info('Started the cluster worker.')
|
||||
|
||||
t = Thread(target=db_logger)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
logger.info('Started background logger.')
|
|
@ -1,9 +0,0 @@
|
|||
import time
|
||||
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
||||
|
||||
def cache_stats():
|
||||
while True:
|
||||
generate_stats(regen=True)
|
||||
time.sleep(5)
|
|
@ -0,0 +1,103 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import warnings
|
||||
from threading import Thread
|
||||
|
||||
import gradio as gr
|
||||
import openai
|
||||
import requests
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
API_BASE = os.getenv('API_BASE')
|
||||
if not API_BASE:
|
||||
print('Must set the secret variable API_BASE to your https://your-site/api')
|
||||
sys.exit(1)
|
||||
API_BASE = API_BASE.strip('/')
|
||||
|
||||
APP_TITLE = os.getenv('APP_TITLE')
|
||||
PRIMARY_MODEL_CHOICE = os.getenv('PRIMARY_MODEL_CHOICE')
|
||||
TRACKING_CODE = os.getenv('TRACKING_CODE')
|
||||
|
||||
|
||||
def background():
|
||||
while True:
|
||||
previous = openai.api_base
|
||||
try:
|
||||
r = requests.get(API_BASE + '/stats').json()
|
||||
if PRIMARY_MODEL_CHOICE in r['models']['choices'].keys():
|
||||
openai.api_base = API_BASE + '/openai/' + PRIMARY_MODEL_CHOICE + '/v1'
|
||||
else:
|
||||
openai.api_base = API_BASE + '/openai/v1'
|
||||
except:
|
||||
traceback.print_exc()
|
||||
openai.api_base = API_BASE + '/openai/v1'
|
||||
if openai.api_base != previous:
|
||||
print('Set primary model to', openai.api_base)
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
if PRIMARY_MODEL_CHOICE:
|
||||
t = Thread(target=background)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the background thread.')
|
||||
|
||||
# A system prompt can be injected into the very first spot in the context.
|
||||
# If the user sends a message that contains the CONTEXT_TRIGGER_PHRASE,
|
||||
# the content in CONTEXT_TRIGGER_INJECTION will be injected.
|
||||
# Setting CONTEXT_TRIGGER_PHRASE will also add it to the selectable examples.
|
||||
CONTEXT_TRIGGER_PHRASE = os.getenv('CONTEXT_TRIGGER_PHRASE')
|
||||
CONTEXT_TRIGGER_INJECTION = os.getenv('CONTEXT_TRIGGER_INJECTION')
|
||||
|
||||
openai.api_key = 'null'
|
||||
openai.api_base = API_BASE + '/openai/v1'
|
||||
|
||||
|
||||
def stream_response(prompt, history):
|
||||
messages = []
|
||||
do_injection = False
|
||||
for human, assistant in history:
|
||||
messages.append({'role': 'user', 'content': str(human)})
|
||||
messages.append({'role': 'assistant', 'content': str(assistant)})
|
||||
|
||||
if CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in human:
|
||||
do_injection = True
|
||||
messages.append({'role': 'user', 'content': prompt})
|
||||
|
||||
if do_injection or (CONTEXT_TRIGGER_INJECTION and CONTEXT_TRIGGER_PHRASE in prompt):
|
||||
messages.insert(0, {'role': 'system', 'content': CONTEXT_TRIGGER_INJECTION})
|
||||
|
||||
try:
|
||||
response = openai.ChatCompletion.create(
|
||||
model='0',
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=300,
|
||||
stream=True,
|
||||
headers={'LLM-Source': 'huggingface-demo'}
|
||||
)
|
||||
except Exception:
|
||||
raise gr.Error("Failed to reach inference endpoint.")
|
||||
|
||||
message = ''
|
||||
for chunk in response:
|
||||
if len(chunk['choices'][0]['delta']) != 0:
|
||||
message += chunk['choices'][0]['delta']['content']
|
||||
yield message
|
||||
|
||||
|
||||
examples = ["hello"]
|
||||
if CONTEXT_TRIGGER_PHRASE:
|
||||
examples.insert(0, CONTEXT_TRIGGER_PHRASE)
|
||||
|
||||
with gr.Blocks(analytics_enabled=False) as demo:
|
||||
gr.ChatInterface(stream_response, examples=examples, title=APP_TITLE, analytics_enabled=False, cache_examples=False, css='#component-0{height:100%!important}')
|
||||
|
||||
if TRACKING_CODE:
|
||||
print('Inserting tracking code')
|
||||
gr.HTML(TRACKING_CODE)
|
||||
|
||||
demo.queue(concurrency_count=1, api_open=False).launch(show_api=False)
|
|
@ -0,0 +1,3 @@
|
|||
gradio
|
||||
openai
|
||||
requests
|
|
@ -1,3 +1,8 @@
|
|||
"""
|
||||
This file is used to run certain tasks when the HTTP server starts.
|
||||
It's located here so it doesn't get imported with daemon.py
|
||||
"""
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
server
|
||||
{
|
||||
listen 443 ssl http2 default_server;
|
||||
server_name _;
|
||||
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header Connection $http_connection;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Scheme $scheme;
|
||||
|
||||
location ~* ^/api/(.*?|v1|openai)/(v1|(generate|stream)|(chat/completions|completions))$
|
||||
{
|
||||
# Route to inference endpoints
|
||||
proxy_pass http://127.0.0.1:5000;
|
||||
|
||||
# Required for streaming (both websockets and SSE).
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
|
||||
# Set long timeouts for inference operations.
|
||||
# Cloudflare has a timeout of 100 seconds.
|
||||
proxy_read_timeout 120;
|
||||
proxy_connect_timeout 120;
|
||||
proxy_send_timeout 120;
|
||||
}
|
||||
|
||||
location /
|
||||
{
|
||||
proxy_pass http://127.0.0.1:5000;
|
||||
}
|
||||
|
||||
ssl_certificate /etc/ssl/certs/nginx-selfsigned.crt;
|
||||
ssl_certificate_key /etc/ssl/private/nginx-selfsigned.key;
|
||||
include /etc/nginx/snippets/ssl-params.conf;
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
HOST="proxy.chub-archive.evulid.cc"
|
||||
|
||||
AUTH_KEY="TEST_1df979f0-6df1-41bd-814a-e99b1680e727"
|
||||
|
||||
PROXY_SERVERS=(
|
||||
"http://172.0.4.7:3128"
|
||||
"http://172.0.4.8:3128"
|
||||
"http://172.0.4.10:3128"
|
||||
"http://172.0.4.12:3128"
|
||||
"http://172.0.4.13:3128"
|
||||
)
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
|
||||
SLEEP_TIME=2
|
||||
|
||||
while getopts p:t: flag; do
|
||||
case "${flag}" in
|
||||
p) PROXY_CHOICE=${OPTARG} ;;
|
||||
t) SLEEP_TIME=${OPTARG} ;;
|
||||
*) ;;
|
||||
esac
|
||||
done
|
||||
|
||||
SOURCE=${BASH_SOURCE[0]}
|
||||
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
|
||||
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
|
||||
SOURCE=$(readlink "$SOURCE")
|
||||
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
|
||||
done
|
||||
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
|
||||
|
||||
source "$DIR/config.sh"
|
||||
|
||||
if [ -n "$PROXY_CHOICE" ]; then
|
||||
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
|
||||
echo "Using $our_proxy_server"
|
||||
else
|
||||
our_proxy_server=""
|
||||
fi
|
||||
|
||||
while true; do
|
||||
echo "--> START <--"
|
||||
|
||||
DATA=$(
|
||||
cat <<EOF
|
||||
{
|
||||
"prompt": "Write a 300 word story about an apple tree.",
|
||||
"temperature": 1,
|
||||
"max_new_tokens": 100,
|
||||
"top_p": 1.0,
|
||||
"top_k": -1,
|
||||
"use_beam_search": false,
|
||||
"stop": ["TEST"],
|
||||
"ignore_eos": false,
|
||||
"presence_penalty": 0.0,
|
||||
"frequency_penalty": 0.0,
|
||||
"length_penalty": 1.0,
|
||||
"early_stopping": false
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
curl "https://$HOST/api/v1/generate" -m 100 -x "$our_proxy_server" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $AUTH_KEY" \
|
||||
-d "$DATA"
|
||||
echo -e "--> DONE <--\n"
|
||||
sleep $SLEEP_TIME
|
||||
done
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
|
||||
DO_STREAM=false
|
||||
SLEEP_TIME=2
|
||||
|
||||
while getopts p:t:s flag; do
|
||||
case "${flag}" in
|
||||
s) DO_STREAM=true ;;
|
||||
p) PROXY_CHOICE=${OPTARG} ;;
|
||||
t) SLEEP_TIME=${OPTARG} ;;
|
||||
*) ;;
|
||||
esac
|
||||
done
|
||||
|
||||
SOURCE=${BASH_SOURCE[0]}
|
||||
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
|
||||
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
|
||||
SOURCE=$(readlink "$SOURCE")
|
||||
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
|
||||
done
|
||||
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
|
||||
|
||||
source "$DIR/config.sh"
|
||||
|
||||
if [ ! -z "$PROXY_CHOICE" ]; then
|
||||
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
|
||||
echo "Using $our_proxy_server"
|
||||
else
|
||||
our_proxy_server=""
|
||||
fi
|
||||
|
||||
while true; do
|
||||
echo "--> START <--"
|
||||
|
||||
DATA=$(
|
||||
cat <<EOF
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"messages": [{"role": "user", "content": "Write a 300 word story about an apple tree."}],
|
||||
"max_tokens": 100,
|
||||
"stream": $DO_STREAM
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
curl "https://$HOST/api/openai/v1/chat/completions" -m 100 -x "$our_proxy_server" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $AUTH_KEY" \
|
||||
-d "$DATA"
|
||||
echo -e "--> DONE <--\n"
|
||||
sleep $SLEEP_TIME
|
||||
done
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
|
||||
DO_STREAM=false
|
||||
SLEEP_TIME=2
|
||||
|
||||
while getopts p:t:s flag; do
|
||||
case "${flag}" in
|
||||
s) DO_STREAM=true ;;
|
||||
p) PROXY_CHOICE=${OPTARG} ;;
|
||||
t) SLEEP_TIME=${OPTARG} ;;
|
||||
*) ;;
|
||||
esac
|
||||
done
|
||||
|
||||
SOURCE=${BASH_SOURCE[0]}
|
||||
while [ -L "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
|
||||
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
|
||||
SOURCE=$(readlink "$SOURCE")
|
||||
[[ $SOURCE != /* ]] && SOURCE=$DIR/$SOURCE # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
|
||||
done
|
||||
DIR=$(cd -P "$(dirname "$SOURCE")" >/dev/null 2>&1 && pwd)
|
||||
|
||||
source "$DIR/config.sh"
|
||||
|
||||
if [ ! -z "$PROXY_CHOICE" ]; then
|
||||
our_proxy_server="${PROXY_SERVERS[$PROXY_CHOICE]}"
|
||||
echo "Using $our_proxy_server"
|
||||
else
|
||||
our_proxy_server=""
|
||||
fi
|
||||
|
||||
while true; do
|
||||
echo "--> START <--"
|
||||
|
||||
DATA=$(
|
||||
cat <<EOF
|
||||
{
|
||||
"model": "gpt-4",
|
||||
"prompt": "Write a 300 word story about an apple tree.",
|
||||
"max_tokens": 100,
|
||||
"stream": $DO_STREAM
|
||||
}
|
||||
EOF
|
||||
)
|
||||
|
||||
curl "https://$HOST/api/openai/v1/completions" -m 100 -x "$our_proxy_server" \
|
||||
-H "Content-Type: application/json" \
|
||||
-H "Authorization: Bearer $AUTH_KEY" \
|
||||
-d "$DATA"
|
||||
echo -e "--> DONE <--\n"
|
||||
sleep $SLEEP_TIME
|
||||
done
|
|
@ -0,0 +1,64 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Function to display help message
|
||||
function display_help {
|
||||
echo "Usage: $0 -n num_windows -c command"
|
||||
echo
|
||||
echo " -n, --number Number of windows to create"
|
||||
echo " -c, --command Command to run in each window"
|
||||
echo
|
||||
exit 1
|
||||
}
|
||||
|
||||
# Parse command line arguments
|
||||
while getopts "n:c:h" opt; do
|
||||
case ${opt} in
|
||||
n)
|
||||
num_windows=${OPTARG}
|
||||
;;
|
||||
c)
|
||||
command=${OPTARG}
|
||||
;;
|
||||
h)
|
||||
display_help
|
||||
;;
|
||||
\?)
|
||||
echo "Invalid option: -$OPTARG" 1>&2
|
||||
display_help
|
||||
;;
|
||||
:)
|
||||
echo "Option -$OPTARG requires an argument." 1>&2
|
||||
display_help
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
# Check if number of windows and command are provided
|
||||
if [ -z "$num_windows" ] || [ -z "$command" ]; then
|
||||
echo "Both number of windows and command are required."
|
||||
display_help
|
||||
fi
|
||||
|
||||
# Calculate rows and columns
|
||||
rows=$(echo "sqrt($num_windows)" | bc)
|
||||
columns=$(echo "($num_windows + $rows - 1) / $rows" | bc)
|
||||
|
||||
# Create a new tmux session
|
||||
tmux new-session -d -s llm_tester "$command -p 0"
|
||||
|
||||
# Create the remaining windows
|
||||
for ((i = 1; i < $num_windows; i++)); do
|
||||
if ((i % $columns == 0)); then
|
||||
tmux select-layout -t llm_tester:0 tiled
|
||||
tmux select-pane -t 0
|
||||
tmux split-window -t llm_tester:0 -v "$command -p $i"
|
||||
else
|
||||
tmux split-window -t llm_tester:0 -h "$command -p $i"
|
||||
fi
|
||||
done
|
||||
|
||||
# Balance the windows
|
||||
tmux select-layout -t llm_tester:0 tiled
|
||||
|
||||
# Attach to the session
|
||||
tmux attach-session -t llm_tester
|
|
@ -1,37 +1,50 @@
|
|||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - ws://
|
||||
HOST = 'localhost:5000'
|
||||
URI = f'ws://{HOST}/api/v1/stream'
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
def parse_bash_config(file_path):
|
||||
config = {}
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f:
|
||||
if line.startswith('#') or '=' not in line:
|
||||
continue
|
||||
key, value = line.strip().split('=', 1)
|
||||
if value.startswith('"') and value.endswith('"'):
|
||||
value = value[1:-1]
|
||||
elif value.startswith('(') and value.endswith(')'):
|
||||
value = value[1:-1].split()
|
||||
value = [v.strip('"') for v in value]
|
||||
config[key] = value
|
||||
return config
|
||||
|
||||
|
||||
config = parse_bash_config(Path(script_path, 'config.sh'))
|
||||
|
||||
|
||||
async def run(context):
|
||||
# Note: the selected defaults change from time to time.
|
||||
request = {
|
||||
'prompt': context,
|
||||
'max_new_tokens': 250,
|
||||
'auto_max_new_tokens': False,
|
||||
'max_tokens_second': 0,
|
||||
|
||||
# Generation params. If 'preset' is set to different than 'None', the values
|
||||
# in presets/preset-name.yaml are used instead of the individual numbers.
|
||||
'preset': 'None',
|
||||
'do_sample': True,
|
||||
'temperature': 0.7,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'epsilon_cutoff': 0, # In units of 1e-4
|
||||
'eta_cutoff': 0, # In units of 1e-4
|
||||
'epsilon_cutoff': 0,
|
||||
'eta_cutoff': 0,
|
||||
'tfs': 1,
|
||||
'top_a': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
|
@ -48,7 +61,6 @@ async def run(context):
|
|||
'mirostat_eta': 0.1,
|
||||
'guidance_scale': 1,
|
||||
'negative_prompt': '',
|
||||
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'truncation_length': 2048,
|
||||
|
@ -58,7 +70,7 @@ async def run(context):
|
|||
'stopping_strings': []
|
||||
}
|
||||
|
||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
||||
async with websockets.connect(f'wss://{config["HOST"]}/api/v1/stream', ping_interval=None) as websocket:
|
||||
await websocket.send(json.dumps(request))
|
||||
|
||||
yield context # Remove this if you just want to see the reply
|
||||
|
@ -67,20 +79,28 @@ async def run(context):
|
|||
incoming_data = await websocket.recv()
|
||||
incoming_data = json.loads(incoming_data)
|
||||
|
||||
print(incoming_data)
|
||||
|
||||
match incoming_data['event']:
|
||||
case 'text_stream':
|
||||
yield incoming_data['text']
|
||||
# case 'text_stream':
|
||||
# yield incoming_data['text']
|
||||
case 'stream_end':
|
||||
return
|
||||
|
||||
|
||||
async def print_response_stream(prompt):
|
||||
async for response in run(prompt):
|
||||
print(response, end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
print('\n\nfinished')
|
||||
try:
|
||||
async for response in run(prompt):
|
||||
print(response, end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
asyncio.run(print_response_stream(prompt))
|
||||
prompt = "Write a 300 word story about an apple tree.\n\n"
|
||||
while True:
|
||||
print('--> START <--')
|
||||
asyncio.run(print_response_stream(prompt))
|
||||
print('--> DONE <--')
|
||||
time.sleep(2)
|
|
@ -1,21 +1,18 @@
|
|||
flask~=2.3.3
|
||||
flask_cors
|
||||
pyyaml~=6.0.1
|
||||
flask_caching
|
||||
Flask-Caching==2.0.2
|
||||
requests~=2.31.0
|
||||
tiktoken~=0.5.0
|
||||
gunicorn
|
||||
gevent~=23.9.0.post1
|
||||
async-timeout
|
||||
flask-sock
|
||||
uvicorn~=0.23.2
|
||||
fastapi~=0.103.1
|
||||
torch~=2.0.1
|
||||
PyMySQL~=1.1.0
|
||||
DBUtils~=3.0.3
|
||||
simplejson~=3.19.1
|
||||
websockets~=11.0.3
|
||||
basicauth~=1.0.0
|
||||
openai~=0.28.0
|
||||
urllib3~=2.0.4
|
||||
celery[redis]
|
||||
flask-sock==0.6.0
|
||||
gunicorn==21.2.0
|
||||
redis==5.0.1
|
||||
ujson==5.8.0
|
||||
vllm==0.2.1.post1
|
||||
gradio~=3.46.1
|
||||
coloredlogs~=15.0.1
|
138
server.py
138
server.py
|
@ -1,5 +1,3 @@
|
|||
from llm_server.config.config import mode_ui_names
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
||||
|
@ -7,37 +5,46 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
from llm_server.pre_fork import server_startup
|
||||
from llm_server.config.load import load_config
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import simplejson as json
|
||||
from flask import Flask, jsonify, render_template, request
|
||||
from flask import Flask, jsonify, render_template, request, Response
|
||||
|
||||
import llm_server
|
||||
import config
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_model_choices
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.config.config import mode_ui_names
|
||||
from llm_server.config.load import load_config
|
||||
from llm_server.custom_redis import flask_cache, redis
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.create import create_db
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.pre_fork import server_startup
|
||||
from llm_server.routes.openai import openai_bp, openai_model_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.stream import init_socketio
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.sock import init_wssocket
|
||||
|
||||
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
|
||||
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
|
||||
# TODO: implement background thread to test backends via sending test prompts
|
||||
# TODO: if backend fails request, mark it as down
|
||||
# TODO: allow setting concurrent gens per-backend
|
||||
# TODO: set the max tokens to that of the lowest backend
|
||||
# TODO: implement RRD backend loadbalancer option
|
||||
# TODO: have VLLM reject a request if it already has n == concurrent_gens running
|
||||
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
||||
# TODO: use coloredlogs
|
||||
# TODO: need to update opts. for workers
|
||||
# TODO: seperate queue item timeout for websockets (make longer, like 5 minutes)
|
||||
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
|
||||
# TODO: what happens when all backends are offline? What about the "online" key in the stats page?
|
||||
# TODO: redis SCAN vs KEYS??
|
||||
# TODO: is frequency penalty the same as ooba repetition penalty???
|
||||
# TODO: make sure openai_moderation_enabled works on websockets, completions, and chat completions
|
||||
|
||||
# Lower priority
|
||||
# TODO: if a backend is at its limit of concurrent requests, choose a different one
|
||||
# TODO: make error messages consitient
|
||||
# TODO: support logit_bias on OpenAI and Ooba endpoints.
|
||||
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
|
||||
# TODO: validate openai_silent_trim works as expected and only when enabled
|
||||
# TODO: rewrite config storage. Store in redis so we can reload it.
|
||||
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
|
||||
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
|
||||
# TODO: the estiamted wait time lags behind the stats
|
||||
# TODO: simulate OpenAI error messages regardless of endpoint
|
||||
|
@ -59,19 +66,16 @@ except ModuleNotFoundError as e:
|
|||
print('Please see README.md for install instructions.')
|
||||
sys.exit(1)
|
||||
|
||||
import config
|
||||
from llm_server import opts
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.routes.cache import RedisWrapper, flask_cache
|
||||
from llm_server.llm import redis
|
||||
from llm_server.routes.stats import get_active_gen_workers
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
||||
app = Flask(__name__)
|
||||
init_socketio(app)
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
|
||||
# Fixes ConcurrentObjectUseError
|
||||
# https://github.com/miguelgrinberg/simple-websocket/issues/24
|
||||
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
|
||||
|
||||
app.register_blueprint(bp, url_prefix='/api/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
|
||||
init_wssocket(app)
|
||||
flask_cache.init_app(app)
|
||||
flask_cache.clear()
|
||||
|
||||
|
@ -82,18 +86,13 @@ if config_path_environ:
|
|||
else:
|
||||
config_path = Path(script_path, 'config', 'config.yml')
|
||||
|
||||
success, config, msg = load_config(config_path, script_path)
|
||||
success, config, msg = load_config(config_path)
|
||||
if not success:
|
||||
print('Failed to load config:', msg)
|
||||
sys.exit(1)
|
||||
|
||||
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
|
||||
create_db()
|
||||
llm_server.llm.redis = RedisWrapper('local_llm')
|
||||
create_db()
|
||||
|
||||
|
||||
# print(app.url_map)
|
||||
|
||||
|
||||
@app.route('/')
|
||||
|
@ -101,20 +100,30 @@ create_db()
|
|||
@app.route('/api/openai')
|
||||
@flask_cache.cached(timeout=10)
|
||||
def home():
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
stats = generate_stats()
|
||||
model_choices, default_model = get_model_choices()
|
||||
|
||||
if not stats['online']:
|
||||
running_model = estimated_wait_sec = 'offline'
|
||||
else:
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
if default_model:
|
||||
if not model_choices.get(default_model):
|
||||
return 'The server is still starting up. Please wait...'
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
|
||||
default_model_info = model_choices[default_model]
|
||||
|
||||
if default_model_info['queued'] == 0 and default_model_info['queued'] >= default_model_info['concurrent_gens']:
|
||||
# There will be a wait if the queue is empty but prompts are processing, but we don't
|
||||
# know how long.
|
||||
estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
|
||||
default_estimated_wait_sec = f"less than {int(default_model_info['estimated_wait'])} seconds"
|
||||
else:
|
||||
estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"
|
||||
default_estimated_wait_sec = f"{int(default_model_info['estimated_wait'])} seconds"
|
||||
else:
|
||||
default_model_info = {
|
||||
'model': 'OFFLINE',
|
||||
'processing': '-',
|
||||
'queued': '-',
|
||||
'context_size': '-',
|
||||
}
|
||||
default_estimated_wait_sec = 'OFFLINE'
|
||||
|
||||
if len(config['analytics_tracking_code']):
|
||||
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
|
||||
|
@ -127,32 +136,47 @@ def home():
|
|||
info_html = ''
|
||||
|
||||
mode_info = ''
|
||||
if opts.mode == 'vllm':
|
||||
mode_info = vllm_info
|
||||
|
||||
base_client_api = redis.get('base_client_api', str)
|
||||
for k, v in cluster_config.all().items():
|
||||
if v['mode'] == 'vllm':
|
||||
mode_info = vllm_info
|
||||
break
|
||||
|
||||
return render_template('home.html',
|
||||
llm_middleware_name=opts.llm_middleware_name,
|
||||
analytics_tracking_code=analytics_tracking_code,
|
||||
info_html=info_html,
|
||||
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
|
||||
default_model=default_model_info['model'],
|
||||
default_active_gen_workers=default_model_info['processing'],
|
||||
default_proompters_in_queue=default_model_info['queued'],
|
||||
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
|
||||
client_api=f'https://{base_client_api}',
|
||||
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||
estimated_wait=estimated_wait_sec,
|
||||
mode_name=mode_ui_names[opts.mode][0],
|
||||
api_input_textbox=mode_ui_names[opts.mode][1],
|
||||
streaming_input_textbox=mode_ui_names[opts.mode][2],
|
||||
context_size=opts.context_size,
|
||||
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else 'disabled',
|
||||
default_estimated_wait=default_estimated_wait_sec,
|
||||
mode_name=mode_ui_names[opts.frontend_api_mode][0],
|
||||
api_input_textbox=mode_ui_names[opts.frontend_api_mode][1],
|
||||
streaming_input_textbox=mode_ui_names[opts.frontend_api_mode][2],
|
||||
default_context_size=default_model_info['context_size'],
|
||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||
extra_info=mode_info,
|
||||
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
expose_openai_system_prompt=opts.expose_openai_system_prompt,
|
||||
enable_streaming=opts.enable_streaming,
|
||||
model_choices=model_choices,
|
||||
proompters_5_min=stats['stats']['proompters']['5_min'],
|
||||
proompters_24_hrs=stats['stats']['proompters']['24_hrs'],
|
||||
)
|
||||
|
||||
|
||||
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
|
||||
@app.route('/robots.txt')
|
||||
def robots():
|
||||
# TODO: have config value to deny all
|
||||
# TODO: https://developers.google.com/search/docs/crawling-indexing/robots/create-robots-txt
|
||||
t = """User-agent: *
|
||||
Allow: /"""
|
||||
r = Response(t)
|
||||
r.headers['Content-Type'] = 'text/plain'
|
||||
return r
|
||||
|
||||
|
||||
@app.route('/<first>')
|
||||
@app.route('/<first>/<path:rest>')
|
||||
|
|
|
@ -65,6 +65,19 @@
|
|||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.header-workers {
|
||||
font-weight: normal;
|
||||
font-size: 14pt;
|
||||
}
|
||||
|
||||
h3 {
|
||||
font-size: 16pt;
|
||||
}
|
||||
|
||||
.no-marker {
|
||||
list-style: none;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
|
@ -76,8 +89,12 @@
|
|||
<h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1>
|
||||
|
||||
<div class="info-box">
|
||||
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
|
||||
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
|
||||
<p><strong>Current Model:</strong> <span id="model">{{ default_model }}</span></p>
|
||||
<p>
|
||||
<strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ default_estimated_wait }}</span><br>
|
||||
Processing: {{ default_active_gen_workers }}<br>
|
||||
Queued: {{ default_proompters_in_queue }}
|
||||
</p>
|
||||
<br>
|
||||
<p><strong>Client API URL:</strong> {{ client_api }}</p>
|
||||
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
|
||||
|
@ -91,17 +108,20 @@
|
|||
<br>
|
||||
|
||||
<div class="info-box">
|
||||
<div id="oobabooga">
|
||||
<strong>Instructions:</strong>
|
||||
<h3>Instructions</h3>
|
||||
<div id="instructions">
|
||||
<ol>
|
||||
<li>In Settings > Power User Options, enable <kbd>Relaxed API URLS</kbd>.</li>
|
||||
<li>Set your API type to <kbd>{{ mode_name }}</kbd></li>
|
||||
<li>Enter <kbd>{{ client_api }}</kbd> in the <kbd>{{ api_input_textbox }}</kbd> textbox.</li>
|
||||
{% if enable_streaming %}<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>{% endif %}
|
||||
{% if enable_streaming %}
|
||||
<li>Enter <kbd>{{ ws_client_api }}</kbd> in the <kbd>{{ streaming_input_textbox }}</kbd> textbox.</li>
|
||||
{% endif %}
|
||||
<li>If you have a token, check the <kbd>Mancer AI</kbd> checkbox and enter your token in the <kbd>Mancer
|
||||
API key</kbd> textbox.
|
||||
</li>
|
||||
<li>Click <kbd>Connect</kbd> to test the connection.</li>
|
||||
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
|
||||
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ default_context_size }}.</li>
|
||||
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
|
||||
</li>
|
||||
</ol>
|
||||
|
@ -120,13 +140,45 @@
|
|||
<br>
|
||||
|
||||
<div class="info-box">
|
||||
<pre><code class="language-json" style="background-color: white">{{ stats_json|safe }}</code></pre>
|
||||
<h3>Statistics</h3>
|
||||
Proompters:
|
||||
<ul style="margin-top: 5px;">
|
||||
<li class="no-marker">5 minutes: {{ proompters_5_min }}</li>
|
||||
<li class="no-marker">24 hours: {{ proompters_24_hrs }}</li>
|
||||
</ul>
|
||||
</div>
|
||||
<br>
|
||||
|
||||
{% for key, value in model_choices.items() %}
|
||||
<div class="info-box">
|
||||
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} {% if value.backend_count == 1 %}worker{% else %}workers{% endif %}</span></h3>
|
||||
|
||||
{% if value.estimated_wait == 0 and value.estimated_wait >= value.concurrent_gens %}
|
||||
{# There will be a wait if the queue is empty but prompts are processing, but we don't know how long. #}
|
||||
{% set estimated_wait_sec = "less than " + value.estimated_wait|int|string + " seconds" %}
|
||||
{% else %}
|
||||
{% set estimated_wait_sec = value.estimated_wait|int|string + " seconds" %}
|
||||
{% endif %}
|
||||
|
||||
<p>
|
||||
<strong>Estimated Wait Time:</strong> {{ estimated_wait_sec }}<br>
|
||||
Processing: {{ value.processing }}<br>
|
||||
Queued: {{ value.queued }}<br>
|
||||
</p>
|
||||
<p>
|
||||
<strong>Client API URL:</strong> {{ value.client_api }}<br>
|
||||
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
|
||||
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
|
||||
</p>
|
||||
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
|
||||
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
|
||||
</div>
|
||||
<br>
|
||||
{% endfor %}
|
||||
</div>
|
||||
<div class="footer">
|
||||
<a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a>
|
||||
</div>
|
||||
<script>hljs.highlightAll();</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
|
|
Reference in New Issue