Merge cluster to master #3
21
daemon.py
21
daemon.py
|
@ -3,9 +3,14 @@ import sys
|
|||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from llm_server.config.load import load_config
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.redis_cycle import redis_cycler_db
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.config.load import load_config, parse_backends
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.create import create_db
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.threader import start_background
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
|
@ -19,16 +24,30 @@ if __name__ == "__main__":
|
|||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
redis_cycler_db.flushall()
|
||||
redis_running_models.flush()
|
||||
|
||||
success, config, msg = load_config(config_path)
|
||||
if not success:
|
||||
print('Failed to load config:', msg)
|
||||
sys.exit(1)
|
||||
|
||||
create_db()
|
||||
|
||||
priority_queue.flush()
|
||||
cluster_config.clear()
|
||||
cluster_config.load(parse_backends(config))
|
||||
|
||||
print('Loading backend stats...')
|
||||
generate_stats()
|
||||
|
||||
start_background()
|
||||
|
||||
redis.set('daemon_started', 1)
|
||||
print('== Daemon Setup Complete ==\n')
|
||||
|
||||
try:
|
||||
while True:
|
||||
time.sleep(3600)
|
||||
except KeyboardInterrupt:
|
||||
redis.set('daemon_started', 0)
|
||||
|
|
|
@ -1,23 +1,34 @@
|
|||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
from llm_server.cluster.redis_cycle import redis_cycle
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.llm.info import get_info
|
||||
|
||||
|
||||
def test_backend(backend_url: str, mode: str):
|
||||
running_model, err = get_running_model(backend_url, mode)
|
||||
if not running_model:
|
||||
return False
|
||||
return True
|
||||
def test_backend(backend_url: str, test_prompt: bool = False):
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if test_prompt:
|
||||
data = {
|
||||
"prompt": "Test prompt",
|
||||
"stream": False,
|
||||
"temperature": 0,
|
||||
"max_new_tokens": 16,
|
||||
}
|
||||
success, response, err = generator(data, backend_url, timeout=10)
|
||||
if not success or not response or err:
|
||||
return False, {}
|
||||
i = get_info(backend_url, backend_info['mode'])
|
||||
if not i.get('model'):
|
||||
return False, {}
|
||||
return True, i
|
||||
|
||||
|
||||
def get_backends():
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
backends = cluster_config.all()
|
||||
result = {}
|
||||
for k, v in backends.items():
|
||||
b = cluster_config.get_backend(k)
|
||||
status = b['online']
|
||||
status = b.get('online', False)
|
||||
priority = b['priority']
|
||||
result[k] = {'status': status, 'priority': priority}
|
||||
online_backends = sorted(
|
||||
|
@ -33,30 +44,43 @@ def get_backends():
|
|||
return [url for url, info in online_backends], [url for url, info in offline_backends]
|
||||
|
||||
|
||||
def get_a_cluster_backend():
|
||||
def get_a_cluster_backend(model=None):
|
||||
"""
|
||||
Get a backend from Redis. If there are no online backends, return None.
|
||||
If `model` is not supplied, we will pick one ourself.
|
||||
"""
|
||||
online, offline = get_backends()
|
||||
cycled = redis_cycle('backend_cycler')
|
||||
c = cycled.copy()
|
||||
for i in range(len(cycled)):
|
||||
if cycled[i] in offline:
|
||||
del c[c.index(cycled[i])]
|
||||
if len(c):
|
||||
return c[0]
|
||||
if model:
|
||||
# First, determine if there are multiple backends hosting the same model.
|
||||
backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]
|
||||
|
||||
# If so, create an iterator for those backends
|
||||
if len(backends_hosting_model):
|
||||
add_backend_cycler(model, backends_hosting_model)
|
||||
cycled = redis_cycle(model)
|
||||
if len(cycled):
|
||||
return cycled[0]
|
||||
else:
|
||||
# No backend hosting that model
|
||||
return None
|
||||
else:
|
||||
online, _ = get_backends()
|
||||
if len(online):
|
||||
return online[0]
|
||||
|
||||
|
||||
def get_backends_from_model(model_name: str):
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
a = cluster_config.all()
|
||||
matches = []
|
||||
for k, v in a.items():
|
||||
if v['online'] and v['running_model'] == model_name:
|
||||
matches.append(k)
|
||||
return matches
|
||||
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
|
||||
|
||||
|
||||
# def verify_context_size(model_name:str):
|
||||
# b = get_backends_from_model(model_name)
|
||||
# for backend_url in b:
|
||||
# backend_info = cluster_config.get_backend(backend_url)
|
||||
# backend_info.get()
|
||||
|
||||
|
||||
def get_running_models():
|
||||
return redis_running_models.keys()
|
||||
|
||||
|
||||
def purge_backend_from_running_models(backend_url: str):
|
||||
|
|
|
@ -0,0 +1,88 @@
|
|||
import numpy as np
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_model, get_running_models
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers
|
||||
|
||||
|
||||
# TODO: give this a better name!
|
||||
def get_model_choices(regen: bool = False):
|
||||
if not regen:
|
||||
c = redis.getp('model_choices')
|
||||
if c:
|
||||
return c
|
||||
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
running_models = get_running_models()
|
||||
model_choices = {}
|
||||
for model in running_models:
|
||||
b = get_backends_from_model(model)
|
||||
|
||||
context_size = []
|
||||
avg_gen_per_worker = []
|
||||
for backend_url in b:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if backend_info.get('model_config'):
|
||||
context_size.append(backend_info['model_config']['max_position_embeddings'])
|
||||
if backend_info.get('average_generation_elapsed_sec'):
|
||||
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
|
||||
|
||||
active_gen_workers = get_active_gen_workers(model)
|
||||
proompters_in_queue = priority_queue.len(model)
|
||||
|
||||
if len(avg_gen_per_worker):
|
||||
average_generation_elapsed_sec = np.average(avg_gen_per_worker)
|
||||
else:
|
||||
average_generation_elapsed_sec = 0
|
||||
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
|
||||
|
||||
if proompters_in_queue == 0 and active_gen_workers >= opts.concurrent_gens:
|
||||
# There will be a wait if the queue is empty but prompts are processing, but we don't
|
||||
# know how long.
|
||||
estimated_wait_sec = f"less than {estimated_wait_sec} seconds"
|
||||
else:
|
||||
estimated_wait_sec = f"{estimated_wait_sec} seconds"
|
||||
|
||||
model_choices[model] = {
|
||||
'client_api': f'https://{base_client_api}/v2/{model}',
|
||||
'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'backend_count': len(b),
|
||||
'estimated_wait': estimated_wait_sec,
|
||||
'queued': proompters_in_queue,
|
||||
'processing': active_gen_workers,
|
||||
'avg_generation_time': average_generation_elapsed_sec
|
||||
}
|
||||
|
||||
if len(context_size):
|
||||
model_choices[model]['context_size'] = min(context_size)
|
||||
|
||||
model_choices = dict(sorted(model_choices.items()))
|
||||
|
||||
default_backend = get_a_cluster_backend()
|
||||
default_backend_info = cluster_config.get_backend(default_backend)
|
||||
default_context_size = default_backend_info['model_config']['max_position_embeddings']
|
||||
default_average_generation_elapsed_sec = default_backend_info.get('average_generation_elapsed_sec')
|
||||
default_active_gen_workers = redis.get(f'active_gen_workers:{default_backend}', dtype=int, default=0)
|
||||
default_proompters_in_queue = priority_queue.len(default_backend_info['model'])
|
||||
default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)
|
||||
|
||||
default_backend_dict = {
|
||||
'client_api': f'https://{base_client_api}/v2',
|
||||
'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None,
|
||||
'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
'estimated_wait': default_estimated_wait_sec,
|
||||
'queued': default_proompters_in_queue,
|
||||
'processing': default_active_gen_workers,
|
||||
'context_size': default_context_size,
|
||||
'hash': default_backend_info['hash'],
|
||||
'model': default_backend_info['model'],
|
||||
'avg_generation_time': default_average_generation_elapsed_sec
|
||||
}
|
||||
|
||||
redis.setp('model_choices', (model_choices, default_backend_dict))
|
||||
|
||||
return model_choices, default_backend_dict
|
|
@ -44,3 +44,6 @@ class RedisClusterStore:
|
|||
return result
|
||||
else:
|
||||
return {}
|
||||
|
||||
# def get(self, name: str):
|
||||
# return self.all().get(name)
|
||||
|
|
|
@ -1,21 +1,35 @@
|
|||
import redis
|
||||
|
||||
r = redis.Redis(host='localhost', port=6379, db=9)
|
||||
redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)
|
||||
|
||||
|
||||
def redis_cycle(list_name):
|
||||
while True:
|
||||
pipe = r.pipeline()
|
||||
pipe.lpop(list_name)
|
||||
popped_element = pipe.execute()[0]
|
||||
if popped_element is None:
|
||||
return None
|
||||
r.rpush(list_name, popped_element)
|
||||
new_list = r.lrange(list_name, 0, -1)
|
||||
"""
|
||||
Emulates itertools.cycle() but returns the complete shuffled list.
|
||||
:param list_name:
|
||||
:return:
|
||||
"""
|
||||
to_move = redis_cycler_db.rpop(list_name)
|
||||
if not to_move:
|
||||
return []
|
||||
redis_cycler_db.lpush(list_name, to_move)
|
||||
new_list = redis_cycler_db.lrange(list_name, 0, -1)
|
||||
return [x.decode('utf-8') for x in new_list]
|
||||
|
||||
|
||||
def load_backend_cycle(list_name: str, elements: list):
|
||||
r.delete(list_name)
|
||||
for element in elements:
|
||||
r.rpush(list_name, element)
|
||||
def add_backend_cycler(list_name: str, new_elements: list):
|
||||
existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
|
||||
existing_set = set(existing_elements)
|
||||
|
||||
with redis_cycler_db.pipeline() as pipe:
|
||||
# Add elements
|
||||
for element in new_elements:
|
||||
if element not in existing_set:
|
||||
pipe.rpush(list_name, element)
|
||||
|
||||
# Remove elements
|
||||
for element in existing_set:
|
||||
if element not in new_elements:
|
||||
pipe.lrem(list_name, 0, element)
|
||||
|
||||
pipe.execute()
|
||||
|
|
|
@ -1,31 +1,42 @@
|
|||
from datetime import datetime
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
from llm_server.cluster.backend import purge_backend_from_running_models, test_backend
|
||||
from llm_server.cluster.backend import test_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.llm.info import get_running_model
|
||||
|
||||
|
||||
def cluster_worker():
|
||||
counter = 0
|
||||
while True:
|
||||
test_prompt = False
|
||||
if counter % 4 == 0:
|
||||
# Only send a test prompt every 120 seconds.
|
||||
test_prompt = True
|
||||
threads = []
|
||||
for n, v in cluster_config.all().items():
|
||||
thread = Thread(target=check_backend, args=(n, v))
|
||||
thread = Thread(target=check_backend, args=(n, v, test_prompt))
|
||||
thread.start()
|
||||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
time.sleep(15)
|
||||
counter += 1
|
||||
|
||||
|
||||
def check_backend(n, v):
|
||||
# Check if backends are online
|
||||
# TODO: also have test_backend() get the uptime
|
||||
online = test_backend(v['backend_url'], v['mode'])
|
||||
def check_backend(n, v, test_prompt):
|
||||
online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
|
||||
# purge_backend_from_running_models(n)
|
||||
if online:
|
||||
running_model, err = get_running_model(v['backend_url'], v['mode'])
|
||||
if not err:
|
||||
cluster_config.set_backend_value(n, 'running_model', running_model)
|
||||
purge_backend_from_running_models(n)
|
||||
running_model = backend_info['model']
|
||||
for k, v in backend_info.items():
|
||||
cluster_config.set_backend_value(n, k, v)
|
||||
redis_running_models.sadd(running_model, n)
|
||||
else:
|
||||
for model in redis_running_models.keys():
|
||||
redis_running_models.srem(model, n)
|
||||
|
||||
# redis_running_models.srem(backend_info['model'], n)
|
||||
# backend_cycler_store.lrem(backend_info['model'], 1, n)
|
||||
|
||||
cluster_config.set_backend_value(n, 'online', online)
|
||||
|
|
|
@ -34,8 +34,9 @@ config_default_vars = {
|
|||
'openai_moderation_enabled': True,
|
||||
'netdata_root': None,
|
||||
'show_backends': True,
|
||||
'cluster_workers': 30
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
config_required_vars = ['cluster', 'mode', 'llm_middleware_name']
|
||||
|
||||
mode_ui_names = {
|
||||
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),
|
||||
|
|
|
@ -26,7 +26,6 @@ def load_config(config_path):
|
|||
opts.log_prompts = config['log_prompts']
|
||||
opts.concurrent_gens = config['concurrent_gens']
|
||||
opts.frontend_api_client = config['frontend_api_client']
|
||||
opts.context_size = config['token_limit']
|
||||
opts.show_num_prompts = config['show_num_prompts']
|
||||
opts.show_uptime = config['show_uptime']
|
||||
opts.cluster = config['cluster']
|
||||
|
@ -53,6 +52,7 @@ def load_config(config_path):
|
|||
opts.openai_silent_trim = config['openai_silent_trim']
|
||||
opts.openai_moderation_enabled = config['openai_moderation_enabled']
|
||||
opts.show_backends = config['show_backends']
|
||||
opts.cluster_workers = config['cluster_workers']
|
||||
|
||||
if opts.openai_expose_our_model and not opts.openai_api_key:
|
||||
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
|
||||
|
|
|
@ -9,17 +9,18 @@ from flask_caching import Cache
|
|||
from redis import Redis
|
||||
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
|
||||
|
||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||
|
||||
ONE_MONTH_SECONDS = 2678000
|
||||
|
||||
|
||||
class RedisCustom:
|
||||
class RedisCustom(Redis):
|
||||
"""
|
||||
A wrapper class to set prefixes to keys.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, **kwargs):
|
||||
super().__init__()
|
||||
self.redis = Redis(**kwargs)
|
||||
self.prefix = prefix
|
||||
try:
|
||||
|
@ -108,6 +109,9 @@ class RedisCustom:
|
|||
):
|
||||
return self.redis.hincrby(self._key(name), key, amount)
|
||||
|
||||
def zcard(self, name: KeyT):
|
||||
return self.redis.zcard(self._key(name))
|
||||
|
||||
def hdel(self, name: str, *keys: List):
|
||||
return self.redis.hdel(self._key(name), *keys)
|
||||
|
||||
|
@ -129,6 +133,9 @@ class RedisCustom:
|
|||
):
|
||||
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
|
||||
|
||||
def lpush(self, name: str, *values: FieldT):
|
||||
return self.redis.lpush(self._key(name), *values)
|
||||
|
||||
def hset(
|
||||
self,
|
||||
name: str,
|
||||
|
@ -164,6 +171,18 @@ class RedisCustom:
|
|||
def pipeline(self, transaction=True, shard_hint=None):
|
||||
return self.redis.pipeline(transaction, shard_hint)
|
||||
|
||||
def smembers(self, name: str):
|
||||
return self.redis.smembers(self._key(name))
|
||||
|
||||
def spop(self, name: str, count: Optional[int] = None):
|
||||
return self.redis.spop(self._key(name), count)
|
||||
|
||||
def rpoplpush(self, src, dst):
|
||||
return self.redis.rpoplpush(src, dst)
|
||||
|
||||
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
|
||||
return self.redis.zpopmin(self._key(name), count)
|
||||
|
||||
def exists(self, *names: KeyT):
|
||||
n = []
|
||||
for name in names:
|
||||
|
@ -196,5 +215,13 @@ class RedisCustom:
|
|||
self.redis.delete(key)
|
||||
return flushed
|
||||
|
||||
def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
|
||||
self.flush()
|
||||
return True
|
||||
|
||||
def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
|
||||
self.flush()
|
||||
return True
|
||||
|
||||
|
||||
redis = RedisCustom('local_llm')
|
||||
|
|
|
@ -5,14 +5,14 @@ from threading import Thread
|
|||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.llm.vllm import tokenize
|
||||
|
||||
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False):
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens: int = None, is_error: bool = False):
|
||||
def background_task():
|
||||
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error
|
||||
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
|
||||
# Try not to shove JSON into the database.
|
||||
if isinstance(response, dict) and response.get('results'):
|
||||
response = response['results'][0]['text']
|
||||
|
@ -23,10 +23,10 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
except:
|
||||
pass
|
||||
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt)
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt, backend_url)
|
||||
if not is_error:
|
||||
if not response_tokens:
|
||||
response_tokens = llm_server.llm.get_token_count(response)
|
||||
response_tokens = llm_server.llm.get_token_count(response, backend_url)
|
||||
else:
|
||||
response_tokens = None
|
||||
|
||||
|
@ -47,7 +47,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
if token:
|
||||
increment_token_uses(token)
|
||||
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = cluster_config.get_backend(backend_url).get('model')
|
||||
timestamp = int(time.time())
|
||||
cursor = database.cursor()
|
||||
try:
|
||||
|
@ -56,7 +56,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
(ip, token, running_model, opts.mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
|
|
|
@ -2,10 +2,10 @@ from llm_server.llm import oobabooga, vllm
|
|||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def get_token_count(prompt: str):
|
||||
def get_token_count(prompt: str, backend_url: str):
|
||||
backend_mode = redis.get('backend_mode', dtype=str)
|
||||
if backend_mode == 'vllm':
|
||||
return vllm.tokenize(prompt)
|
||||
return vllm.tokenize(prompt, backend_url)
|
||||
elif backend_mode == 'ooba':
|
||||
return oobabooga.tokenize(prompt)
|
||||
else:
|
||||
|
|
|
@ -1,14 +1,13 @@
|
|||
from llm_server import opts
|
||||
|
||||
|
||||
def generator(request_json_body, cluster_backend):
|
||||
def generator(request_json_body, cluster_backend, timeout: int = None):
|
||||
if opts.mode == 'oobabooga':
|
||||
# from .oobabooga.generate import generate
|
||||
# return generate(request_json_body)
|
||||
raise NotImplementedError
|
||||
elif opts.mode == 'vllm':
|
||||
from .vllm.generate import generate
|
||||
r = generate(request_json_body, cluster_backend)
|
||||
return r
|
||||
return generate(request_json_body, cluster_backend, timeout=timeout)
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -20,3 +20,18 @@ def get_running_model(backend_url: str, mode: str):
|
|||
return False, e
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
||||
def get_info(backend_url: str, mode: str):
|
||||
if mode == 'ooba':
|
||||
return {}
|
||||
# raise NotImplementedError
|
||||
elif mode == 'vllm':
|
||||
try:
|
||||
r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
|
||||
j = r.json()
|
||||
except Exception as e:
|
||||
return {}
|
||||
return j
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -3,13 +3,17 @@ from typing import Tuple, Union
|
|||
import flask
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.llm import get_token_count
|
||||
|
||||
|
||||
class LLMBackend:
|
||||
_default_params: dict
|
||||
|
||||
def __init__(self, backend_url: str):
|
||||
self.backend_url = backend_url
|
||||
|
||||
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -38,8 +42,9 @@ class LLMBackend:
|
|||
return True, None
|
||||
|
||||
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
|
||||
prompt_len = get_token_count(prompt)
|
||||
if prompt_len > opts.context_size - 10:
|
||||
prompt_len = get_token_count(prompt, self.backend_url)
|
||||
token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
|
||||
if prompt_len > token_limit - 10:
|
||||
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size'
|
||||
return True, None
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from flask import jsonify
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from ..llm_backend import LLMBackend
|
||||
from ...database.database import log_prompt
|
||||
from ...helpers import safe_list_get
|
||||
from llm_server.custom_redis import redis
|
||||
from ...routes.helpers.client import format_sillytavern_err
|
||||
from ...routes.helpers.http import validate_json
|
||||
|
||||
|
@ -33,7 +33,7 @@ class OobaboogaBackend(LLMBackend):
|
|||
error_msg = 'Unknown error.'
|
||||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = format_sillytavern_err(error_msg, 'error')
|
||||
backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
|
||||
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
|
@ -50,7 +50,8 @@ class OobaboogaBackend(LLMBackend):
|
|||
backend_err = True
|
||||
backend_response = format_sillytavern_err(
|
||||
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
|
||||
'error')
|
||||
error_type='error',
|
||||
backend_url=self.backend_url)
|
||||
response_json_body['results'][0]['text'] = backend_response
|
||||
|
||||
if not backend_err:
|
||||
|
@ -61,7 +62,7 @@ class OobaboogaBackend(LLMBackend):
|
|||
**response_json_body
|
||||
}), 200
|
||||
else:
|
||||
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
|
||||
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
|
||||
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
|
||||
return jsonify({
|
||||
'code': 500,
|
||||
|
|
|
@ -24,57 +24,6 @@ def prepare_json(json_data: dict):
|
|||
return json_data
|
||||
|
||||
|
||||
def transform_to_text(json_request, api_response):
|
||||
"""
|
||||
This is to convert a streaming request to a non-streamed request. Don't think this is nessesary.
|
||||
:param json_request:
|
||||
:param api_response:
|
||||
:return:
|
||||
"""
|
||||
prompt = transform_prompt_to_text(json_request['messages'])
|
||||
text = ''
|
||||
finish_reason = None
|
||||
for line in api_response.split('\n'):
|
||||
if line.startswith('data:'):
|
||||
try:
|
||||
data = json.loads(line[5:].strip())
|
||||
except json.decoder.JSONDecodeError:
|
||||
break
|
||||
if 'choices' in data:
|
||||
for choice in data['choices']:
|
||||
if 'delta' in choice and 'content' in choice['delta']:
|
||||
text += choice['delta']['content']
|
||||
if data['choices'][0]['finish_reason']:
|
||||
finish_reason = data['choices'][0]['finish_reason']
|
||||
|
||||
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
|
||||
completion_tokens = len(llm_server.llm.get_token_count(text))
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
|
||||
return {
|
||||
"id": str(uuid4()),
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": running_model,
|
||||
"usage": {
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": prompt_tokens + completion_tokens
|
||||
},
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
},
|
||||
"finish_reason": finish_reason,
|
||||
"index": 0
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def transform_prompt_to_text(prompt: list):
|
||||
text = ''
|
||||
for item in prompt:
|
||||
|
@ -82,26 +31,26 @@ def transform_prompt_to_text(prompt: list):
|
|||
return text.strip('\n')
|
||||
|
||||
|
||||
def handle_blocking_request(json_data: dict, cluster_backend):
|
||||
def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
|
||||
try:
|
||||
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
|
||||
except requests.exceptions.ReadTimeout:
|
||||
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
|
||||
# print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
|
||||
return False, None, 'Request to backend timed out'
|
||||
except Exception as e:
|
||||
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
return False, None, 'Request to backend encountered error'
|
||||
if r.status_code != 200:
|
||||
print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
|
||||
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
|
||||
return False, r, f'Backend returned {r.status_code}'
|
||||
return True, r, None
|
||||
|
||||
|
||||
def generate(json_data: dict, cluster_backend):
|
||||
def generate(json_data: dict, cluster_backend, timeout: int = None):
|
||||
if json_data.get('stream'):
|
||||
try:
|
||||
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
|
||||
except Exception as e:
|
||||
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
return False
|
||||
else:
|
||||
return handle_blocking_request(json_data, cluster_backend)
|
||||
return handle_blocking_request(json_data, cluster_backend, timeout=timeout)
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
|
||||
<strong>Supported Parameters:</strong>
|
||||
<ul>
|
||||
|
|
|
@ -2,19 +2,21 @@ import requests
|
|||
import tiktoken
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
def tokenize(prompt: str) -> int:
|
||||
def tokenize(prompt: str, backend_url: str) -> int:
|
||||
if not prompt:
|
||||
# The tokenizers have issues when the prompt is None.
|
||||
return 0
|
||||
tokenizer = tiktoken.get_encoding("cl100k_base")
|
||||
token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings']
|
||||
|
||||
# First we tokenize it locally to determine if it's worth sending it to the backend.
|
||||
initial_estimate = len(tokenizer.encode(prompt))
|
||||
if initial_estimate <= opts.context_size + 200:
|
||||
if initial_estimate <= token_limit + 200:
|
||||
try:
|
||||
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
j = r.json()
|
||||
return j['length']
|
||||
except Exception as e:
|
||||
|
|
|
@ -20,7 +20,7 @@ class VLLMBackend(LLMBackend):
|
|||
backend_response = ''
|
||||
|
||||
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
|
||||
|
||||
return jsonify({'results': [{'text': backend_response}]}), 200
|
||||
|
||||
|
|
|
@ -5,7 +5,6 @@
|
|||
concurrent_gens = 3
|
||||
mode = 'oobabooga'
|
||||
backend_url = None
|
||||
context_size = 5555
|
||||
max_new_tokens = 500
|
||||
auth_required = False
|
||||
log_prompts = False
|
||||
|
@ -38,3 +37,4 @@ openai_silent_trim = False
|
|||
openai_moderation_enabled = True
|
||||
cluster = {}
|
||||
show_backends = True
|
||||
cluster_workers = 30
|
||||
|
|
|
@ -1,21 +1,9 @@
|
|||
import sys
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
||||
|
||||
def server_startup(s):
|
||||
if not redis.get('daemon_started', dtype=bool):
|
||||
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
|
||||
sys.exit(1)
|
||||
|
||||
# Flush the RedisPriorityQueue database.
|
||||
queue_redis = Redis(host='localhost', port=6379, db=15)
|
||||
for key in queue_redis.scan_iter('*'):
|
||||
queue_redis.delete(key)
|
||||
|
||||
# Cache the initial stats
|
||||
print('Loading backend stats...')
|
||||
generate_stats()
|
||||
|
|
|
@ -2,13 +2,14 @@ from llm_server.cluster.cluster_config import cluster_config
|
|||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'):
|
||||
cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash']
|
||||
def format_sillytavern_err(msg: str, backend_url: str = 'none', error_type: str = 'info'):
|
||||
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
|
||||
http_host = redis.get('http_host', dtype=str)
|
||||
return f"""```
|
||||
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
|
||||
-> {level.upper()} <-
|
||||
-> {error_type.upper()} <-
|
||||
{msg}
|
||||
|
||||
BACKEND HASH: {cluster_backend_hash}
|
||||
```
|
||||
```
|
||||
BACKEND: {cluster_backend_hash}
|
||||
```"""
|
||||
|
|
|
@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
backend_response = self.handle_error(msg)
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
|
||||
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
|
@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
# TODO: how to format this
|
||||
response_msg = error_msg
|
||||
else:
|
||||
response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend)
|
||||
response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
|
||||
|
||||
return jsonify({
|
||||
'results': [{'text': response_msg}]
|
||||
|
|
|
@ -6,7 +6,7 @@ from uuid import uuid4
|
|||
from redis import Redis
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
|
||||
|
||||
def increment_ip_count(client_ip: str, redis_key):
|
||||
|
@ -20,12 +20,12 @@ def decrement_ip_count(client_ip: str, redis_key):
|
|||
|
||||
|
||||
class RedisPriorityQueue:
|
||||
def __init__(self):
|
||||
self.redis = Redis(host='localhost', port=6379, db=15)
|
||||
def __init__(self, name: str = 'priority_queue', db: int = 12):
|
||||
self.redis = RedisCustom(name, db=db)
|
||||
self.pubsub = self.redis.pubsub()
|
||||
self.pubsub.subscribe('events')
|
||||
|
||||
def put(self, item, priority):
|
||||
def put(self, item, priority, selected_model):
|
||||
event = DataEvent()
|
||||
|
||||
# Check if the IP is already in the dictionary and if it has reached the limit
|
||||
|
@ -36,7 +36,7 @@ class RedisPriorityQueue:
|
|||
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
|
||||
return None # reject the request
|
||||
|
||||
self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
|
||||
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
|
||||
self.increment_ip_count(item[1], 'queued_ip_count')
|
||||
return event
|
||||
|
||||
|
@ -61,12 +61,23 @@ class RedisPriorityQueue:
|
|||
def __len__(self):
|
||||
return self.redis.zcard('queue')
|
||||
|
||||
def len(self, model_name):
|
||||
count = 0
|
||||
for key in self.redis.zrange('queue', 0, -1):
|
||||
item = json.loads(key)
|
||||
if item[2] == model_name:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def get_queued_ip_count(self, client_ip: str):
|
||||
q = self.redis.hget('queued_ip_count', client_ip)
|
||||
if not q:
|
||||
return 0
|
||||
return 0
|
||||
|
||||
def flush(self):
|
||||
self.redis.flush()
|
||||
|
||||
|
||||
class DataEvent:
|
||||
def __init__(self, event_id=None):
|
||||
|
@ -87,12 +98,16 @@ class DataEvent:
|
|||
priority_queue = RedisPriorityQueue()
|
||||
|
||||
|
||||
def incr_active_workers():
|
||||
redis.incr('active_gen_workers')
|
||||
def incr_active_workers(selected_model: str, backend_url: str):
|
||||
redis.incr(f'active_gen_workers:{selected_model}')
|
||||
redis.incr(f'active_gen_workers:{backend_url}')
|
||||
|
||||
|
||||
def decr_active_workers():
|
||||
redis.decr('active_gen_workers')
|
||||
new_count = redis.get('active_gen_workers', 0, dtype=int)
|
||||
if new_count < 0:
|
||||
redis.set('active_gen_workers', 0)
|
||||
def decr_active_workers(selected_model: str, backend_url: str):
|
||||
redis.decr(f'active_gen_workers:{selected_model}')
|
||||
if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0:
|
||||
redis.set(f'active_gen_workers:{selected_model}', 0)
|
||||
|
||||
redis.decr(f'active_gen_workers:{backend_url}')
|
||||
if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0:
|
||||
redis.set(f'active_gen_workers:{backend_url}', 0)
|
||||
|
|
|
@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
|||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.routes.helpers.http import require_api_key, validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
|
||||
|
||||
DEFAULT_PRIORITY = 9999
|
||||
|
||||
|
||||
class RequestHandler:
|
||||
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
|
||||
def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
|
||||
self.request = incoming_request
|
||||
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
|
||||
|
||||
|
@ -37,11 +37,12 @@ class RequestHandler:
|
|||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.get_auth_token()
|
||||
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
|
||||
self.cluster_backend = get_a_cluster_backend()
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
|
||||
self.backend = get_backend_handler(self.cluster_backend)
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
self.selected_model = selected_model
|
||||
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||
|
||||
def get_auth_token(self):
|
||||
|
@ -123,7 +124,7 @@ class RequestHandler:
|
|||
backend_response = self.handle_error(combined_error_message, 'Validation Error')
|
||||
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
|
||||
return False, backend_response
|
||||
return True, (None, 0)
|
||||
|
||||
|
@ -135,14 +136,16 @@ class RequestHandler:
|
|||
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
|
||||
if not request_valid:
|
||||
return (False, None, None, 0), invalid_response
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority)
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.backend_url), self.token_priority, self.selected_model)
|
||||
else:
|
||||
event = None
|
||||
|
||||
if not event:
|
||||
return (False, None, None, 0), self.handle_ratelimited()
|
||||
|
||||
# TODO: add wait timeout
|
||||
success, response, error_msg = event.wait()
|
||||
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - self.start_time
|
||||
|
||||
|
@ -164,7 +167,7 @@ class RequestHandler:
|
|||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -184,7 +187,7 @@ class RequestHandler:
|
|||
if return_json_err:
|
||||
error_msg = 'The backend did not return valid JSON.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -218,11 +221,11 @@ class RequestHandler:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_backend_handler(mode):
|
||||
def get_backend_handler(mode, backend_url: str):
|
||||
if mode == 'oobabooga':
|
||||
return OobaboogaBackend()
|
||||
return OobaboogaBackend(backend_url)
|
||||
elif mode == 'vllm':
|
||||
return VLLMBackend()
|
||||
return VLLMBackend(backend_url)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from datetime import datetime
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.helpers import round_up_base
|
||||
|
||||
server_start_time = datetime.now()
|
||||
|
||||
|
@ -14,10 +15,32 @@ def get_total_proompts():
|
|||
return count
|
||||
|
||||
|
||||
def get_active_gen_workers():
|
||||
active_gen_workers = redis.get('active_gen_workers')
|
||||
def get_active_gen_workers(selected_model: str = None, ):
|
||||
active_gen_workers = redis.get(f'active_gen_workers:{selected_model}')
|
||||
if active_gen_workers is None:
|
||||
count = 0
|
||||
else:
|
||||
count = int(active_gen_workers)
|
||||
return count
|
||||
|
||||
|
||||
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
|
||||
if active_gen_workers < concurrent_gens:
|
||||
return 0
|
||||
elif active_gen_workers >= concurrent_gens:
|
||||
# Calculate how long it will take to complete the currently running gens and the queued requests.
|
||||
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
|
||||
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
|
||||
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
|
||||
# Add gen_time_calc to the time to account for the currently running generations.
|
||||
# This assumes that all active workers will finish at the same time, which is unlikely.
|
||||
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
|
||||
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
|
||||
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
|
||||
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
|
||||
elif proompters_in_queue == 0 and active_gen_workers == 0:
|
||||
# No queue, no workers
|
||||
return 0
|
||||
else:
|
||||
# No queue
|
||||
return gen_time_calc
|
||||
|
|
|
@ -3,18 +3,20 @@ import traceback
|
|||
from flask import jsonify, request
|
||||
|
||||
from . import bp
|
||||
from ..helpers.client import format_sillytavern_err
|
||||
from ..helpers.http import validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
@bp.route('/generate', methods=['POST'])
|
||||
def generate():
|
||||
@bp.route('/v1/generate', methods=['POST'])
|
||||
@bp.route('/<model_name>/v1/generate', methods=['POST'])
|
||||
def generate(model_name=None):
|
||||
request_valid_json, request_json_body = validate_json(request)
|
||||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
|
||||
else:
|
||||
handler = OobaRequestHandler(request)
|
||||
handler = OobaRequestHandler(request, model_name)
|
||||
try:
|
||||
return handler.handle_request()
|
||||
except Exception:
|
||||
|
|
|
@ -2,74 +2,32 @@ import time
|
|||
from datetime import datetime
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_a_cluster_backend, test_backend
|
||||
from llm_server.cluster.backend import get_a_cluster_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.model_choices import get_model_choices
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_distinct_ips_24h, sum_column
|
||||
from llm_server.helpers import deep_sort, round_up_base
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
|
||||
|
||||
|
||||
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
|
||||
if active_gen_workers < concurrent_gens:
|
||||
return 0
|
||||
elif active_gen_workers >= concurrent_gens:
|
||||
# Calculate how long it will take to complete the currently running gens and the queued requests.
|
||||
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
|
||||
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
|
||||
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
|
||||
# Add gen_time_calc to the time to account for the currently running generations.
|
||||
# This assumes that all active workers will finish at the same time, which is unlikely.
|
||||
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
|
||||
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
|
||||
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
|
||||
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
|
||||
elif proompters_in_queue == 0 and active_gen_workers == 0:
|
||||
# No queue, no workers
|
||||
return 0
|
||||
else:
|
||||
# No queue
|
||||
return gen_time_calc
|
||||
from llm_server.helpers import deep_sort
|
||||
from llm_server.routes.stats import get_total_proompts, server_start_time
|
||||
|
||||
|
||||
def generate_stats(regen: bool = False):
|
||||
if not regen:
|
||||
c = redis.get('proxy_stats', dtype=dict)
|
||||
c = redis.getp('proxy_stats')
|
||||
if c:
|
||||
return c
|
||||
|
||||
default_backend_url = get_a_cluster_backend()
|
||||
default_backend_info = cluster_config.get_backend(default_backend_url)
|
||||
if not default_backend_info.get('mode'):
|
||||
# TODO: remove
|
||||
print('DAEMON NOT FINISHED STARTING')
|
||||
return
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
|
||||
average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0)
|
||||
|
||||
online = test_backend(default_backend_url, default_backend_info['mode'])
|
||||
if online:
|
||||
running_model, err = get_running_model(default_backend_url, default_backend_info['mode'])
|
||||
cluster_config.set_backend_value(default_backend_url, 'running_model', running_model)
|
||||
else:
|
||||
running_model = None
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
proompters_in_queue = len(priority_queue)
|
||||
|
||||
# This is so wildly inaccurate it's disabled.
|
||||
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
|
||||
|
||||
# TODO: make this for the currently selected backend
|
||||
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
|
||||
|
||||
output = {
|
||||
'default': {
|
||||
'model': running_model,
|
||||
'backend': default_backend_info['hash'],
|
||||
'model': default_backend_info['model'],
|
||||
'backend': default_backend_url,
|
||||
},
|
||||
'stats': {
|
||||
'proompters': {
|
||||
|
@ -78,21 +36,14 @@ def generate_stats(regen: bool = False):
|
|||
},
|
||||
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
|
||||
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
|
||||
'average_generation_elapsed_sec': int(average_generation_elapsed_sec),
|
||||
# 'estimated_avg_tps': estimated_avg_tps,
|
||||
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
|
||||
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
|
||||
},
|
||||
'online': online,
|
||||
'endpoints': {
|
||||
'blocking': f'https://{base_client_api}',
|
||||
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||
},
|
||||
'queue': {
|
||||
'processing': active_gen_workers,
|
||||
'queued': proompters_in_queue,
|
||||
'estimated_wait_sec': int(estimated_wait_sec),
|
||||
},
|
||||
'timestamp': int(time.time()),
|
||||
'config': {
|
||||
'gatekeeper': 'none' if opts.auth_required is False else 'token',
|
||||
|
@ -106,42 +57,30 @@ def generate_stats(regen: bool = False):
|
|||
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
|
||||
}
|
||||
|
||||
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
|
||||
|
||||
if opts.show_backends:
|
||||
for backend_url, v in cluster_config.all().items():
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if not backend_info['online']:
|
||||
continue
|
||||
|
||||
# TODO: have this fetch the data from VLLM which will display GPU utalization
|
||||
# if opts.netdata_root:
|
||||
# netdata_stats = {}
|
||||
# power_states = get_power_states()
|
||||
# for gpu, power_state in power_states.items():
|
||||
# netdata_stats[gpu] = {
|
||||
# 'power_state': power_state,
|
||||
# # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
|
||||
# }
|
||||
# else:
|
||||
# netdata_stats = {}
|
||||
netdata_stats = {}
|
||||
|
||||
# TODO: use value returned by VLLM backend here
|
||||
# backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None
|
||||
backend_uptime = -1
|
||||
|
||||
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
|
||||
output['backend_info'][backend_info['hash']] = {
|
||||
'uptime': backend_uptime,
|
||||
# 'context_size': opts.context_size,
|
||||
'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'),
|
||||
'max_tokens': backend_info['model_config']['max_position_embeddings'],
|
||||
'model': backend_info['model'],
|
||||
'mode': backend_info['mode'],
|
||||
'nvidia': netdata_stats
|
||||
'nvidia': backend_info['nvidia'],
|
||||
}
|
||||
else:
|
||||
output['backend_info'] = {}
|
||||
|
||||
output['default'] = get_model_choices(regen=True)[1]
|
||||
|
||||
result = deep_sort(output)
|
||||
|
||||
# It may take a bit to get the base client API, so don't cache until then.
|
||||
if base_client_api:
|
||||
redis.set_dict('proxy_stats', result) # Cache with no expiry
|
||||
redis.setp('proxy_stats', result)
|
||||
|
||||
return result
|
||||
|
|
|
@ -10,13 +10,9 @@ from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends
|
|||
from ...cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
@bp.route('/model', methods=['GET'])
|
||||
@bp.route('/<model_name>/model', methods=['GET'])
|
||||
@bp.route('/v1/model', methods=['GET'])
|
||||
@bp.route('/<model_name>/v1/model', methods=['GET'])
|
||||
def get_model(model_name=None):
|
||||
if not model_name:
|
||||
b = get_a_cluster_backend()
|
||||
model_name = cluster_config.get_backend(b)['running_model']
|
||||
|
||||
# We will manage caching ourself since we don't want to cache
|
||||
# when the backend is down. Also, Cloudflare won't cache 500 errors.
|
||||
cache_key = 'model_cache::' + request.url
|
||||
|
@ -25,6 +21,9 @@ def get_model(model_name=None):
|
|||
if cached_response:
|
||||
return cached_response
|
||||
|
||||
if not model_name:
|
||||
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
|
||||
|
||||
if not is_valid_model(model_name):
|
||||
response = jsonify({
|
||||
'code': 400,
|
||||
|
@ -32,7 +31,6 @@ def get_model(model_name=None):
|
|||
}), 400
|
||||
else:
|
||||
num_backends = len(get_backends_from_model(model_name))
|
||||
|
||||
response = jsonify({
|
||||
'result': opts.manual_model_name if opts.manual_model_name else model_name,
|
||||
'model_backend_count': num_backends,
|
||||
|
@ -47,7 +45,8 @@ def get_model(model_name=None):
|
|||
@requires_auth
|
||||
def get_backend():
|
||||
online, offline = get_backends()
|
||||
result = []
|
||||
result = {}
|
||||
for i in online + offline:
|
||||
result.append(cluster_config.get_backend(i))
|
||||
info = cluster_config.get_backend(i)
|
||||
result[info['hash']] = info
|
||||
return jsonify(result), 200
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import threading
|
||||
import time
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
|
||||
|
@ -9,12 +9,16 @@ from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip
|
|||
|
||||
def worker():
|
||||
while True:
|
||||
need_to_wait()
|
||||
(request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get()
|
||||
need_to_wait()
|
||||
(request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
|
||||
if not selected_model:
|
||||
selected_model = cluster_config.get_backend(backend_url)['model']
|
||||
|
||||
# This wait time is "invisible", meaning the worker may as
|
||||
# well be still waiting to get an item from the queue.
|
||||
need_to_wait(backend_url)
|
||||
|
||||
increment_ip_count(client_ip, 'processing_ips')
|
||||
incr_active_workers()
|
||||
incr_active_workers(selected_model, backend_url)
|
||||
|
||||
if not request_json_body:
|
||||
# This was a dummy request from the websocket handler.
|
||||
|
@ -22,12 +26,12 @@ def worker():
|
|||
continue
|
||||
|
||||
try:
|
||||
success, response, error_msg = generator(request_json_body, cluster_backend)
|
||||
success, response, error_msg = generator(request_json_body, backend_url)
|
||||
event = DataEvent(event_id)
|
||||
event.set((success, response, error_msg))
|
||||
finally:
|
||||
decrement_ip_count(client_ip, 'processing_ips')
|
||||
decr_active_workers()
|
||||
decr_active_workers(selected_model, backend_url)
|
||||
|
||||
|
||||
def start_workers(num_workers: int):
|
||||
|
@ -40,11 +44,12 @@ def start_workers(num_workers: int):
|
|||
print(f'Started {i} inference workers.')
|
||||
|
||||
|
||||
def need_to_wait():
|
||||
def need_to_wait(backend_url: str):
|
||||
# We need to check the number of active workers since the streaming endpoint may be doing something.
|
||||
active_workers = redis.get('active_gen_workers', 0, dtype=int)
|
||||
active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int)
|
||||
concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
|
||||
s = time.time()
|
||||
while active_workers >= opts.concurrent_gens:
|
||||
while active_workers >= concurrent_gens:
|
||||
time.sleep(0.01)
|
||||
e = time.time()
|
||||
if e - s > 0.5:
|
||||
|
|
|
@ -5,7 +5,7 @@ from llm_server.cluster.backend import get_a_cluster_backend, get_backends
|
|||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import weighted_average_column_for_model
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.llm.info import get_info, get_running_model
|
||||
|
||||
|
||||
def main_background_thread():
|
||||
|
@ -14,8 +14,9 @@ def main_background_thread():
|
|||
for backend_url in online:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
backend_mode = backend_info['mode']
|
||||
running_model, err = get_running_model(backend_url, backend_mode)
|
||||
if err:
|
||||
backend_info = get_info(backend_url, backend_mode)
|
||||
running_model = backend_info.get('model')
|
||||
if not running_model:
|
||||
continue
|
||||
|
||||
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
|
||||
|
@ -25,21 +26,6 @@ def main_background_thread():
|
|||
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
|
||||
if average_generation_elapsed_sec and average_output_tokens:
|
||||
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
|
||||
|
||||
default_backend_url = get_a_cluster_backend()
|
||||
default_backend_info = cluster_config.get_backend(default_backend_url)
|
||||
default_backend_mode = default_backend_info['mode']
|
||||
default_running_model, err = get_running_model(default_backend_url, default_backend_mode)
|
||||
if err:
|
||||
continue
|
||||
|
||||
default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_running_model, default_running_model, default_backend_mode)
|
||||
if default_average_generation_elapsed_sec:
|
||||
redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec)
|
||||
if default_average_output_tokens:
|
||||
redis.set('average_output_tokens', default_average_output_tokens)
|
||||
if default_average_generation_elapsed_sec and default_average_output_tokens:
|
||||
redis.set('estimated_avg_tps', default_estimated_avg_tps)
|
||||
time.sleep(30)
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import time
|
||||
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
|
@ -17,9 +18,11 @@ if not logger.handlers:
|
|||
def console_printer():
|
||||
time.sleep(3)
|
||||
while True:
|
||||
processing = redis.hkeys('processing_ips')
|
||||
processing = redis.keys('active_gen_workers:http*') # backends always start with http
|
||||
processing_count = 0
|
||||
for ip in processing:
|
||||
processing_count += int(redis.hget('processing_ips', ip))
|
||||
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
|
||||
if len(processing):
|
||||
for k in processing:
|
||||
processing_count += redis.get(k, default=0, dtype=int)
|
||||
backends = [k for k, v in cluster_config.all().items() if v['online']]
|
||||
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
|
||||
time.sleep(10)
|
||||
|
|
|
@ -15,11 +15,11 @@ from llm_server.workers.recenter import recent_prompters_thread
|
|||
def cache_stats():
|
||||
while True:
|
||||
generate_stats(regen=True)
|
||||
time.sleep(1)
|
||||
time.sleep(5)
|
||||
|
||||
|
||||
def start_background():
|
||||
start_workers(opts.concurrent_gens)
|
||||
start_workers(opts.cluster_workers)
|
||||
|
||||
t = Thread(target=main_background_thread)
|
||||
t.daemon = True
|
||||
|
|
|
@ -1,20 +1,17 @@
|
|||
flask~=2.3.3
|
||||
flask_cors
|
||||
pyyaml~=6.0.1
|
||||
flask_caching
|
||||
requests~=2.31.0
|
||||
tiktoken~=0.5.0
|
||||
gunicorn
|
||||
gevent~=23.9.0.post1
|
||||
async-timeout
|
||||
flask-sock
|
||||
uvicorn~=0.23.2
|
||||
fastapi~=0.103.1
|
||||
torch~=2.0.1
|
||||
PyMySQL~=1.1.0
|
||||
DBUtils~=3.0.3
|
||||
simplejson~=3.19.1
|
||||
websockets~=11.0.3
|
||||
basicauth~=1.0.0
|
||||
openai~=0.28.0
|
||||
urllib3~=2.0.4
|
||||
flask-sock==0.6.0
|
||||
gunicorn==21.2.0
|
||||
redis==5.0.1
|
||||
git+https://github.com/vllm-project/vllm
|
69
server.py
69
server.py
|
@ -1,5 +1,3 @@
|
|||
from llm_server.cluster.cluster_config import cluster_config
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
||||
|
@ -14,10 +12,10 @@ from pathlib import Path
|
|||
import simplejson as json
|
||||
from flask import Flask, jsonify, render_template, request
|
||||
|
||||
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
|
||||
from llm_server.cluster.redis_cycle import load_backend_cycle
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.model_choices import get_model_choices
|
||||
from llm_server.config.config import mode_ui_names
|
||||
from llm_server.config.load import load_config, parse_backends
|
||||
from llm_server.config.load import load_config
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.create import create_db
|
||||
from llm_server.pre_fork import server_startup
|
||||
|
@ -26,10 +24,7 @@ from llm_server.routes.server_error import handle_server_error
|
|||
from llm_server.routes.v1 import bp
|
||||
from llm_server.sock import init_socketio
|
||||
|
||||
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
|
||||
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
|
||||
# TODO: implement background thread to test backends via sending test prompts
|
||||
# TODO: if backend fails request, mark it as down
|
||||
# TODO: per-backend workers
|
||||
# TODO: allow setting concurrent gens per-backend
|
||||
# TODO: set the max tokens to that of the lowest backend
|
||||
# TODO: implement RRD backend loadbalancer option
|
||||
|
@ -42,6 +37,7 @@ from llm_server.sock import init_socketio
|
|||
# TODO: have VLLM report context size, uptime
|
||||
|
||||
# Lower priority
|
||||
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
|
||||
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
|
||||
# TODO: the estiamted wait time lags behind the stats
|
||||
# TODO: simulate OpenAI error messages regardless of endpoint
|
||||
|
@ -69,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api
|
|||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from llm_server.llm import redis
|
||||
from llm_server.routes.stats import get_active_gen_workers
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
||||
app = Flask(__name__)
|
||||
init_socketio(app)
|
||||
app.register_blueprint(bp, url_prefix='/api/v1/')
|
||||
app.register_blueprint(bp, url_prefix='/api/v2/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
flask_cache.init_app(app)
|
||||
flask_cache.clear()
|
||||
|
@ -94,37 +89,23 @@ if not success:
|
|||
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
|
||||
create_db()
|
||||
|
||||
cluster_config.clear()
|
||||
cluster_config.load(parse_backends(config))
|
||||
on, off = get_backends()
|
||||
load_backend_cycle('backend_cycler', on + off)
|
||||
|
||||
|
||||
@app.route('/')
|
||||
@app.route('/api')
|
||||
@app.route('/api/openai')
|
||||
@flask_cache.cached(timeout=10)
|
||||
def home():
|
||||
# Use the default backend
|
||||
backend_url = get_a_cluster_backend()
|
||||
if backend_url:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
stats = generate_stats(backend_url)
|
||||
else:
|
||||
backend_info = stats = None
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
stats = generate_stats()
|
||||
|
||||
if not stats['online']:
|
||||
running_model = estimated_wait_sec = 'offline'
|
||||
else:
|
||||
running_model = backend_info['running_model']
|
||||
model_choices, default_backend_info = get_model_choices()
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
|
||||
if default_backend_info['queued'] == 0 and default_backend_info['queued'] >= opts.concurrent_gens:
|
||||
# There will be a wait if the queue is empty but prompts are processing, but we don't
|
||||
# know how long.
|
||||
estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
|
||||
default_estimated_wait_sec = f"less than {default_backend_info['estimated_wait']} seconds"
|
||||
else:
|
||||
estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"
|
||||
default_estimated_wait_sec = f"{default_backend_info['estimated_wait']} seconds"
|
||||
|
||||
if len(config['analytics_tracking_code']):
|
||||
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
|
||||
|
@ -137,39 +118,35 @@ def home():
|
|||
info_html = ''
|
||||
|
||||
mode_info = ''
|
||||
using_vllm = False
|
||||
for k, v in cluster_config.all().items():
|
||||
if v['mode'] == vllm:
|
||||
using_vllm = True
|
||||
break
|
||||
|
||||
if using_vllm == 'vllm':
|
||||
if v['mode'] == 'vllm':
|
||||
mode_info = vllm_info
|
||||
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
break
|
||||
|
||||
return render_template('home.html',
|
||||
llm_middleware_name=opts.llm_middleware_name,
|
||||
analytics_tracking_code=analytics_tracking_code,
|
||||
info_html=info_html,
|
||||
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
|
||||
client_api=f'https://{base_client_api}',
|
||||
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
|
||||
estimated_wait=estimated_wait_sec,
|
||||
default_model=default_backend_info['model'],
|
||||
default_active_gen_workers=default_backend_info['processing'],
|
||||
default_proompters_in_queue=default_backend_info['queued'],
|
||||
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
|
||||
client_api=f'https://{base_client_api}/v2',
|
||||
ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled',
|
||||
default_estimated_wait=default_estimated_wait_sec,
|
||||
mode_name=mode_ui_names[opts.mode][0],
|
||||
api_input_textbox=mode_ui_names[opts.mode][1],
|
||||
streaming_input_textbox=mode_ui_names[opts.mode][2],
|
||||
context_size=opts.context_size,
|
||||
default_context_size=default_backend_info['context_size'],
|
||||
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
|
||||
extra_info=mode_info,
|
||||
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
|
||||
expose_openai_system_prompt=opts.expose_openai_system_prompt,
|
||||
enable_streaming=opts.enable_streaming,
|
||||
model_choices=model_choices
|
||||
)
|
||||
|
||||
|
||||
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
|
||||
|
||||
@app.route('/<first>')
|
||||
@app.route('/<first>/<path:rest>')
|
||||
def fallback(first=None, rest=None):
|
||||
|
|
|
@ -65,6 +65,10 @@
|
|||
.hidden {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.header-workers {
|
||||
font-weight: normal;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
|
@ -76,8 +80,12 @@
|
|||
<h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1>
|
||||
|
||||
<div class="info-box">
|
||||
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
|
||||
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
|
||||
<p><strong>Current Model:</strong> <span id="model">{{ default_model }}</span></p>
|
||||
<p>
|
||||
<strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ default_estimated_wait }}</span><br>
|
||||
Processing: {{ default_active_gen_workers }}<br>
|
||||
Queued: {{ default_proompters_in_queue }}
|
||||
</p>
|
||||
<br>
|
||||
<p><strong>Client API URL:</strong> {{ client_api }}</p>
|
||||
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
|
||||
|
@ -101,7 +109,7 @@
|
|||
API key</kbd> textbox.
|
||||
</li>
|
||||
<li>Click <kbd>Connect</kbd> to test the connection.</li>
|
||||
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
|
||||
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ default_context_size }}.</li>
|
||||
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
|
||||
</li>
|
||||
</ol>
|
||||
|
@ -119,9 +127,30 @@
|
|||
|
||||
<br>
|
||||
|
||||
{% for key, value in model_choices.items() %}
|
||||
<div class="info-box">
|
||||
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} workers</span></h3>
|
||||
<p>
|
||||
<strong>Estimated Wait Time:</strong> {{ value.estimated_wait }}<br>
|
||||
Processing: {{ value.processing }}<br>
|
||||
Queued: {{ value.queued }}<br>
|
||||
</p>
|
||||
<p>
|
||||
<strong>Client API URL:</strong> {{ value.client_api }}<br>
|
||||
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
|
||||
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
|
||||
</p>
|
||||
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
|
||||
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
|
||||
</div>
|
||||
<br>
|
||||
{% endfor %}
|
||||
|
||||
<!--
|
||||
<div class="info-box">
|
||||
<pre><code class="language-json" style="background-color: white">{{ stats_json|safe }}</code></pre>
|
||||
</div>
|
||||
-->
|
||||
</div>
|
||||
<div class="footer">
|
||||
<a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a>
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
try:
|
||||
import gevent.monkey
|
||||
|
||||
gevent.monkey.patch_all()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import time
|
||||
from threading import Thread
|
||||
from llm_server.cluster.redis_cycle import load_backend_cycle
|
||||
|
||||
from llm_server.cluster.backend import get_backends, get_a_cluster_backend
|
||||
from llm_server.cluster.worker import cluster_worker
|
||||
from llm_server.config.load import parse_backends, load_config
|
||||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('config')
|
||||
args = parser.parse_args()
|
||||
|
||||
success, config, msg = load_config(args.config)
|
||||
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
cluster_config.clear()
|
||||
cluster_config.load(parse_backends(config))
|
||||
on, off = get_backends()
|
||||
load_backend_cycle('backend_cycler', on + off)
|
||||
|
||||
t = Thread(target=cluster_worker)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
while True:
|
||||
# online, offline = get_backends()
|
||||
# print(online, offline)
|
||||
# print(get_a_cluster_backend())
|
||||
time.sleep(3)
|
Reference in New Issue