Merge cluster to master #3

Merged
cyberes merged 163 commits from cluster into master 2023-10-27 19:19:22 -06:00
38 changed files with 505 additions and 415 deletions
Showing only changes of commit 114f36e709 - Show all commits

View File

@ -3,9 +3,14 @@ import sys
import time
from pathlib import Path
from llm_server.config.load import load_config
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.redis_cycle import redis_cycler_db
from llm_server.cluster.stores import redis_running_models
from llm_server.config.load import load_config, parse_backends
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.routes.queue import priority_queue
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
@ -19,16 +24,30 @@ if __name__ == "__main__":
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
redis_cycler_db.flushall()
redis_running_models.flush()
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
create_db()
priority_queue.flush()
cluster_config.clear()
cluster_config.load(parse_backends(config))
print('Loading backend stats...')
generate_stats()
start_background()
redis.set('daemon_started', 1)
print('== Daemon Setup Complete ==\n')
try:
while True:
time.sleep(3600)
except KeyboardInterrupt:
redis.set('daemon_started', 0)

View File

@ -1,23 +1,34 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.cluster.redis_cycle import redis_cycle
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.redis_cycle import add_backend_cycler, redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.llm.info import get_running_model
from llm_server.llm.generator import generator
from llm_server.llm.info import get_info
def test_backend(backend_url: str, mode: str):
running_model, err = get_running_model(backend_url, mode)
if not running_model:
return False
return True
def test_backend(backend_url: str, test_prompt: bool = False):
backend_info = cluster_config.get_backend(backend_url)
if test_prompt:
data = {
"prompt": "Test prompt",
"stream": False,
"temperature": 0,
"max_new_tokens": 16,
}
success, response, err = generator(data, backend_url, timeout=10)
if not success or not response or err:
return False, {}
i = get_info(backend_url, backend_info['mode'])
if not i.get('model'):
return False, {}
return True, i
def get_backends():
cluster_config = RedisClusterStore('cluster_config')
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b['online']
status = b.get('online', False)
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
online_backends = sorted(
@ -33,30 +44,43 @@ def get_backends():
return [url for url, info in online_backends], [url for url, info in offline_backends]
def get_a_cluster_backend():
def get_a_cluster_backend(model=None):
"""
Get a backend from Redis. If there are no online backends, return None.
If `model` is not supplied, we will pick one ourself.
"""
online, offline = get_backends()
cycled = redis_cycle('backend_cycler')
c = cycled.copy()
for i in range(len(cycled)):
if cycled[i] in offline:
del c[c.index(cycled[i])]
if len(c):
return c[0]
if model:
# First, determine if there are multiple backends hosting the same model.
backends_hosting_model = [i.decode('utf-8') for i in redis_running_models.smembers(model)]
# If so, create an iterator for those backends
if len(backends_hosting_model):
add_backend_cycler(model, backends_hosting_model)
cycled = redis_cycle(model)
if len(cycled):
return cycled[0]
else:
# No backend hosting that model
return None
else:
online, _ = get_backends()
if len(online):
return online[0]
def get_backends_from_model(model_name: str):
cluster_config = RedisClusterStore('cluster_config')
a = cluster_config.all()
matches = []
for k, v in a.items():
if v['online'] and v['running_model'] == model_name:
matches.append(k)
return matches
return [x.decode('utf-8') for x in redis_running_models.smembers(model_name)]
# def verify_context_size(model_name:str):
# b = get_backends_from_model(model_name)
# for backend_url in b:
# backend_info = cluster_config.get_backend(backend_url)
# backend_info.get()
def get_running_models():
return redis_running_models.keys()
def purge_backend_from_running_models(backend_url: str):

View File

@ -0,0 +1,88 @@
import numpy as np
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, get_backends_from_model, get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_wait_time, get_active_gen_workers
# TODO: give this a better name!
def get_model_choices(regen: bool = False):
if not regen:
c = redis.getp('model_choices')
if c:
return c
base_client_api = redis.get('base_client_api', dtype=str)
running_models = get_running_models()
model_choices = {}
for model in running_models:
b = get_backends_from_model(model)
context_size = []
avg_gen_per_worker = []
for backend_url in b:
backend_info = cluster_config.get_backend(backend_url)
if backend_info.get('model_config'):
context_size.append(backend_info['model_config']['max_position_embeddings'])
if backend_info.get('average_generation_elapsed_sec'):
avg_gen_per_worker.append(backend_info['average_generation_elapsed_sec'])
active_gen_workers = get_active_gen_workers(model)
proompters_in_queue = priority_queue.len(model)
if len(avg_gen_per_worker):
average_generation_elapsed_sec = np.average(avg_gen_per_worker)
else:
average_generation_elapsed_sec = 0
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
if proompters_in_queue == 0 and active_gen_workers >= opts.concurrent_gens:
# There will be a wait if the queue is empty but prompts are processing, but we don't
# know how long.
estimated_wait_sec = f"less than {estimated_wait_sec} seconds"
else:
estimated_wait_sec = f"{estimated_wait_sec} seconds"
model_choices[model] = {
'client_api': f'https://{base_client_api}/v2/{model}',
'ws_client_api': f'wss://{base_client_api}/v2/{model}/stream' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/v2/{model}' if opts.enable_openi_compatible_backend else 'disabled',
'backend_count': len(b),
'estimated_wait': estimated_wait_sec,
'queued': proompters_in_queue,
'processing': active_gen_workers,
'avg_generation_time': average_generation_elapsed_sec
}
if len(context_size):
model_choices[model]['context_size'] = min(context_size)
model_choices = dict(sorted(model_choices.items()))
default_backend = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend)
default_context_size = default_backend_info['model_config']['max_position_embeddings']
default_average_generation_elapsed_sec = default_backend_info.get('average_generation_elapsed_sec')
default_active_gen_workers = redis.get(f'active_gen_workers:{default_backend}', dtype=int, default=0)
default_proompters_in_queue = priority_queue.len(default_backend_info['model'])
default_estimated_wait_sec = calculate_wait_time(default_average_generation_elapsed_sec, default_proompters_in_queue, default_backend_info['concurrent_gens'], default_active_gen_workers)
default_backend_dict = {
'client_api': f'https://{base_client_api}/v2',
'ws_client_api': f'wss://{base_client_api}/v2' if opts.enable_streaming else None,
'openai_client_api': f'https://{base_client_api}/openai/v2' if opts.enable_openi_compatible_backend else 'disabled',
'estimated_wait': default_estimated_wait_sec,
'queued': default_proompters_in_queue,
'processing': default_active_gen_workers,
'context_size': default_context_size,
'hash': default_backend_info['hash'],
'model': default_backend_info['model'],
'avg_generation_time': default_average_generation_elapsed_sec
}
redis.setp('model_choices', (model_choices, default_backend_dict))
return model_choices, default_backend_dict

View File

@ -44,3 +44,6 @@ class RedisClusterStore:
return result
else:
return {}
# def get(self, name: str):
# return self.all().get(name)

View File

@ -1,21 +1,35 @@
import redis
r = redis.Redis(host='localhost', port=6379, db=9)
redis_cycler_db = redis.Redis(host='localhost', port=6379, db=9)
def redis_cycle(list_name):
while True:
pipe = r.pipeline()
pipe.lpop(list_name)
popped_element = pipe.execute()[0]
if popped_element is None:
return None
r.rpush(list_name, popped_element)
new_list = r.lrange(list_name, 0, -1)
"""
Emulates itertools.cycle() but returns the complete shuffled list.
:param list_name:
:return:
"""
to_move = redis_cycler_db.rpop(list_name)
if not to_move:
return []
redis_cycler_db.lpush(list_name, to_move)
new_list = redis_cycler_db.lrange(list_name, 0, -1)
return [x.decode('utf-8') for x in new_list]
def load_backend_cycle(list_name: str, elements: list):
r.delete(list_name)
for element in elements:
r.rpush(list_name, element)
def add_backend_cycler(list_name: str, new_elements: list):
existing_elements = [i.decode('utf-8') for i in redis_cycler_db.lrange(list_name, 0, -1)]
existing_set = set(existing_elements)
with redis_cycler_db.pipeline() as pipe:
# Add elements
for element in new_elements:
if element not in existing_set:
pipe.rpush(list_name, element)
# Remove elements
for element in existing_set:
if element not in new_elements:
pipe.lrem(list_name, 0, element)
pipe.execute()

View File

@ -1,31 +1,42 @@
from datetime import datetime
import time
from threading import Thread
from llm_server.cluster.backend import purge_backend_from_running_models, test_backend
from llm_server.cluster.backend import test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
from llm_server.llm.info import get_running_model
def cluster_worker():
counter = 0
while True:
test_prompt = False
if counter % 4 == 0:
# Only send a test prompt every 120 seconds.
test_prompt = True
threads = []
for n, v in cluster_config.all().items():
thread = Thread(target=check_backend, args=(n, v))
thread = Thread(target=check_backend, args=(n, v, test_prompt))
thread.start()
threads.append(thread)
for thread in threads:
thread.join()
time.sleep(15)
counter += 1
def check_backend(n, v):
# Check if backends are online
# TODO: also have test_backend() get the uptime
online = test_backend(v['backend_url'], v['mode'])
def check_backend(n, v, test_prompt):
online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
# purge_backend_from_running_models(n)
if online:
running_model, err = get_running_model(v['backend_url'], v['mode'])
if not err:
cluster_config.set_backend_value(n, 'running_model', running_model)
purge_backend_from_running_models(n)
running_model = backend_info['model']
for k, v in backend_info.items():
cluster_config.set_backend_value(n, k, v)
redis_running_models.sadd(running_model, n)
else:
for model in redis_running_models.keys():
redis_running_models.srem(model, n)
# redis_running_models.srem(backend_info['model'], n)
# backend_cycler_store.lrem(backend_info['model'], 1, n)
cluster_config.set_backend_value(n, 'online', online)

View File

@ -34,8 +34,9 @@ config_default_vars = {
'openai_moderation_enabled': True,
'netdata_root': None,
'show_backends': True,
'cluster_workers': 30
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
config_required_vars = ['cluster', 'mode', 'llm_middleware_name']
mode_ui_names = {
'oobabooga': ('Text Gen WebUI (ooba)', 'Blocking API url', 'Streaming API url'),

View File

@ -26,7 +26,6 @@ def load_config(config_path):
opts.log_prompts = config['log_prompts']
opts.concurrent_gens = config['concurrent_gens']
opts.frontend_api_client = config['frontend_api_client']
opts.context_size = config['token_limit']
opts.show_num_prompts = config['show_num_prompts']
opts.show_uptime = config['show_uptime']
opts.cluster = config['cluster']
@ -53,6 +52,7 @@ def load_config(config_path):
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
opts.cluster_workers = config['cluster_workers']
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')

View File

@ -9,17 +9,18 @@ from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/15', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
ONE_MONTH_SECONDS = 2678000
class RedisCustom:
class RedisCustom(Redis):
"""
A wrapper class to set prefixes to keys.
"""
def __init__(self, prefix, **kwargs):
super().__init__()
self.redis = Redis(**kwargs)
self.prefix = prefix
try:
@ -108,6 +109,9 @@ class RedisCustom:
):
return self.redis.hincrby(self._key(name), key, amount)
def zcard(self, name: KeyT):
return self.redis.zcard(self._key(name))
def hdel(self, name: str, *keys: List):
return self.redis.hdel(self._key(name), *keys)
@ -129,6 +133,9 @@ class RedisCustom:
):
return self.redis.zadd(self._key(name), mapping, nx, xx, ch, incr, gt, lt)
def lpush(self, name: str, *values: FieldT):
return self.redis.lpush(self._key(name), *values)
def hset(
self,
name: str,
@ -164,6 +171,18 @@ class RedisCustom:
def pipeline(self, transaction=True, shard_hint=None):
return self.redis.pipeline(transaction, shard_hint)
def smembers(self, name: str):
return self.redis.smembers(self._key(name))
def spop(self, name: str, count: Optional[int] = None):
return self.redis.spop(self._key(name), count)
def rpoplpush(self, src, dst):
return self.redis.rpoplpush(src, dst)
def zpopmin(self, name: KeyT, count: Union[int, None] = None):
return self.redis.zpopmin(self._key(name), count)
def exists(self, *names: KeyT):
n = []
for name in names:
@ -196,5 +215,13 @@ class RedisCustom:
self.redis.delete(key)
return flushed
def flushall(self, asynchronous: bool = ..., **kwargs) -> bool:
self.flush()
return True
def flushdb(self, asynchronous: bool = ..., **kwargs) -> bool:
self.flush()
return True
redis = RedisCustom('local_llm')

View File

@ -5,14 +5,14 @@ from threading import Thread
import llm_server
from llm_server import opts
from llm_server.custom_redis import redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False):
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens: int = None, is_error: bool = False):
def background_task():
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, backend_url, response_tokens, is_error
# Try not to shove JSON into the database.
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
@ -23,10 +23,10 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
except:
pass
prompt_tokens = llm_server.llm.get_token_count(prompt)
prompt_tokens = llm_server.llm.get_token_count(prompt, backend_url)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)
response_tokens = llm_server.llm.get_token_count(response, backend_url)
else:
response_tokens = None
@ -47,7 +47,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
if token:
increment_token_uses(token)
running_model = redis.get('running_model', str, 'ERROR')
running_model = cluster_config.get_backend(backend_url).get('model')
timestamp = int(time.time())
cursor = database.cursor()
try:
@ -56,7 +56,7 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
(ip, token, running_model, opts.mode, backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()

View File

@ -2,10 +2,10 @@ from llm_server.llm import oobabooga, vllm
from llm_server.custom_redis import redis
def get_token_count(prompt: str):
def get_token_count(prompt: str, backend_url: str):
backend_mode = redis.get('backend_mode', dtype=str)
if backend_mode == 'vllm':
return vllm.tokenize(prompt)
return vllm.tokenize(prompt, backend_url)
elif backend_mode == 'ooba':
return oobabooga.tokenize(prompt)
else:

View File

@ -1,14 +1,13 @@
from llm_server import opts
def generator(request_json_body, cluster_backend):
def generator(request_json_body, cluster_backend, timeout: int = None):
if opts.mode == 'oobabooga':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm':
from .vllm.generate import generate
r = generate(request_json_body, cluster_backend)
return r
return generate(request_json_body, cluster_backend, timeout=timeout)
else:
raise Exception

View File

@ -20,3 +20,18 @@ def get_running_model(backend_url: str, mode: str):
return False, e
else:
raise Exception
def get_info(backend_url: str, mode: str):
if mode == 'ooba':
return {}
# raise NotImplementedError
elif mode == 'vllm':
try:
r = requests.get(f'{backend_url}/info', verify=opts.verify_ssl, timeout=opts.backend_request_timeout)
j = r.json()
except Exception as e:
return {}
return j
else:
raise Exception

View File

@ -3,13 +3,17 @@ from typing import Tuple, Union
import flask
from llm_server import opts
from llm_server.llm import get_token_count
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.llm import get_token_count
class LLMBackend:
_default_params: dict
def __init__(self, backend_url: str):
self.backend_url = backend_url
def handle_response(self, success, request: flask.Request, response_json_body: dict, response_status_code: int, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError
@ -38,8 +42,9 @@ class LLMBackend:
return True, None
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = get_token_count(prompt)
if prompt_len > opts.context_size - 10:
prompt_len = get_token_count(prompt, self.backend_url)
token_limit = cluster_config.get_backend(self.backend_url)['model_config']['max_position_embeddings']
if prompt_len > token_limit - 10:
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {token_limit}, model: {model_name}). Please lower your context size'
return True, None

View File

@ -1,9 +1,9 @@
from flask import jsonify
from llm_server.custom_redis import redis
from ..llm_backend import LLMBackend
from ...database.database import log_prompt
from ...helpers import safe_list_get
from llm_server.custom_redis import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
@ -33,7 +33,7 @@ class OobaboogaBackend(LLMBackend):
error_msg = 'Unknown error.'
else:
error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, 'error')
backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
return jsonify({
'code': 500,
@ -50,7 +50,8 @@ class OobaboogaBackend(LLMBackend):
backend_err = True
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
'error')
error_type='error',
backend_url=self.backend_url)
response_json_body['results'][0]['text'] = backend_response
if not backend_err:
@ -61,7 +62,7 @@ class OobaboogaBackend(LLMBackend):
**response_json_body
}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
return jsonify({
'code': 500,

View File

@ -24,57 +24,6 @@ def prepare_json(json_data: dict):
return json_data
def transform_to_text(json_request, api_response):
"""
This is to convert a streaming request to a non-streamed request. Don't think this is nessesary.
:param json_request:
:param api_response:
:return:
"""
prompt = transform_prompt_to_text(json_request['messages'])
text = ''
finish_reason = None
for line in api_response.split('\n'):
if line.startswith('data:'):
try:
data = json.loads(line[5:].strip())
except json.decoder.JSONDecodeError:
break
if 'choices' in data:
for choice in data['choices']:
if 'delta' in choice and 'content' in choice['delta']:
text += choice['delta']['content']
if data['choices'][0]['finish_reason']:
finish_reason = data['choices'][0]['finish_reason']
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
completion_tokens = len(llm_server.llm.get_token_count(text))
running_model = redis.get('running_model', 'ERROR', dtype=str)
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
return {
"id": str(uuid4()),
"object": "chat.completion",
"created": int(time.time()),
"model": running_model,
"usage": {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens
},
"choices": [
{
"message": {
"role": "assistant",
"content": text
},
"finish_reason": finish_reason,
"index": 0
}
]
}
def transform_prompt_to_text(prompt: list):
text = ''
for item in prompt:
@ -82,26 +31,26 @@ def transform_prompt_to_text(prompt: list):
return text.strip('\n')
def handle_blocking_request(json_data: dict, cluster_backend):
def handle_blocking_request(json_data: dict, cluster_backend, timeout: int = 10):
try:
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except requests.exceptions.ReadTimeout:
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
# print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out'
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
# print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False, None, 'Request to backend encountered error'
if r.status_code != 200:
print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
# print(f'Failed to reach VLLM inference endpoint - got code {r.status_code}')
return False, r, f'Backend returned {r.status_code}'
return True, r, None
def generate(json_data: dict, cluster_backend):
def generate(json_data: dict, cluster_backend, timeout: int = None):
if json_data.get('stream'):
try:
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout if not timeout else timeout)
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
return False
else:
return handle_blocking_request(json_data, cluster_backend)
return handle_blocking_request(json_data, cluster_backend, timeout=timeout)

View File

@ -1,3 +1,7 @@
import requests
from llm_server import opts
vllm_info = """<p><strong>Important:</strong> This endpoint is running <a href="https://github.com/vllm-project/vllm" target="_blank">vllm</a> and not all Oobabooga parameters are supported.</p>
<strong>Supported Parameters:</strong>
<ul>

View File

@ -2,19 +2,21 @@ import requests
import tiktoken
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
def tokenize(prompt: str) -> int:
def tokenize(prompt: str, backend_url: str) -> int:
if not prompt:
# The tokenizers have issues when the prompt is None.
return 0
tokenizer = tiktoken.get_encoding("cl100k_base")
token_limit = cluster_config.get_backend(backend_url)['model_config']['max_position_embeddings']
# First we tokenize it locally to determine if it's worth sending it to the backend.
initial_estimate = len(tokenizer.encode(prompt))
if initial_estimate <= opts.context_size + 200:
if initial_estimate <= token_limit + 200:
try:
r = requests.post(f'{opts.backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{backend_url}/tokenize', json={'input': prompt}, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
j = r.json()
return j['length']
except Exception as e:

View File

@ -20,7 +20,7 @@ class VLLMBackend(LLMBackend):
backend_response = ''
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
response_tokens=response_json_body.get('details', {}).get('generated_tokens'), backend_url=self.backend_url)
return jsonify({'results': [{'text': backend_response}]}), 200

View File

@ -5,7 +5,6 @@
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None
context_size = 5555
max_new_tokens = 500
auth_required = False
log_prompts = False
@ -38,3 +37,4 @@ openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True
cluster_workers = 30

View File

@ -1,21 +1,9 @@
import sys
from redis import Redis
from llm_server.custom_redis import redis
from llm_server.routes.v1.generate_stats import generate_stats
def server_startup(s):
if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)
# Flush the RedisPriorityQueue database.
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
# Cache the initial stats
print('Loading backend stats...')
generate_stats()

View File

@ -2,13 +2,14 @@ from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'):
cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash']
def format_sillytavern_err(msg: str, backend_url: str = 'none', error_type: str = 'info'):
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
http_host = redis.get('http_host', dtype=str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
-> {error_type.upper()} <-
{msg}
BACKEND HASH: {cluster_backend_hash}
```
```
BACKEND: {cluster_backend_hash}
```"""

View File

@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this
response_msg = error_msg
else:
response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend)
response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
return jsonify({
'results': [{'text': response_msg}]

View File

@ -6,7 +6,7 @@ from uuid import uuid4
from redis import Redis
from llm_server import opts
from llm_server.custom_redis import redis
from llm_server.custom_redis import RedisCustom, redis
def increment_ip_count(client_ip: str, redis_key):
@ -20,12 +20,12 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue:
def __init__(self):
self.redis = Redis(host='localhost', port=6379, db=15)
def __init__(self, name: str = 'priority_queue', db: int = 12):
self.redis = RedisCustom(name, db=db)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events')
def put(self, item, priority):
def put(self, item, priority, selected_model):
event = DataEvent()
# Check if the IP is already in the dictionary and if it has reached the limit
@ -36,7 +36,7 @@ class RedisPriorityQueue:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request
self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
return event
@ -61,12 +61,23 @@ class RedisPriorityQueue:
def __len__(self):
return self.redis.zcard('queue')
def len(self, model_name):
count = 0
for key in self.redis.zrange('queue', 0, -1):
item = json.loads(key)
if item[2] == model_name:
count += 1
return count
def get_queued_ip_count(self, client_ip: str):
q = self.redis.hget('queued_ip_count', client_ip)
if not q:
return 0
return 0
def flush(self):
self.redis.flush()
class DataEvent:
def __init__(self, event_id=None):
@ -87,12 +98,16 @@ class DataEvent:
priority_queue = RedisPriorityQueue()
def incr_active_workers():
redis.incr('active_gen_workers')
def incr_active_workers(selected_model: str, backend_url: str):
redis.incr(f'active_gen_workers:{selected_model}')
redis.incr(f'active_gen_workers:{backend_url}')
def decr_active_workers():
redis.decr('active_gen_workers')
new_count = redis.get('active_gen_workers', 0, dtype=int)
if new_count < 0:
redis.set('active_gen_workers', 0)
def decr_active_workers(selected_model: str, backend_url: str):
redis.decr(f'active_gen_workers:{selected_model}')
if redis.get(f'active_gen_workers:{selected_model}', 0, dtype=int) < 0:
redis.set(f'active_gen_workers:{selected_model}', 0)
redis.decr(f'active_gen_workers:{backend_url}')
if redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) < 0:
redis.set(f'active_gen_workers:{backend_url}', 0)

View File

@ -15,13 +15,13 @@ from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
DEFAULT_PRIORITY = 9999
class RequestHandler:
def __init__(self, incoming_request: flask.Request, incoming_json: Union[dict, str] = None):
def __init__(self, incoming_request: flask.Request, selected_model: str, incoming_json: Union[dict, str] = None):
self.request = incoming_request
self.enable_backend_blind_rrd = request.headers.get('LLM-Blind-RRD', False) == 'true'
@ -37,11 +37,12 @@ class RequestHandler:
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
self.cluster_backend = get_a_cluster_backend()
self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
self.backend = get_backend_handler(self.cluster_backend)
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
self.backend = get_backend_handler(self.cluster_backend_info['mode'], self.backend_url)
self.parameters = None
self.used = False
self.selected_model = selected_model
redis.zadd('recent_prompters', {self.client_ip: time.time()})
def get_auth_token(self):
@ -123,7 +124,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.backend_url, is_error=True)
return False, backend_response
return True, (None, 0)
@ -135,14 +136,16 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid:
return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority)
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.backend_url), self.token_priority, self.selected_model)
else:
event = None
if not event:
return (False, None, None, 0), self.handle_ratelimited()
# TODO: add wait timeout
success, response, error_msg = event.wait()
end_time = time.time()
elapsed_time = end_time - self.start_time
@ -164,7 +167,7 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -184,7 +187,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.backend_url, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -218,11 +221,11 @@ class RequestHandler:
raise NotImplementedError
def get_backend_handler(mode):
def get_backend_handler(mode, backend_url: str):
if mode == 'oobabooga':
return OobaboogaBackend()
return OobaboogaBackend(backend_url)
elif mode == 'vllm':
return VLLMBackend()
return VLLMBackend(backend_url)
else:
raise Exception

View File

@ -1,6 +1,7 @@
from datetime import datetime
from llm_server.custom_redis import redis
from llm_server.helpers import round_up_base
server_start_time = datetime.now()
@ -14,10 +15,32 @@ def get_total_proompts():
return count
def get_active_gen_workers():
active_gen_workers = redis.get('active_gen_workers')
def get_active_gen_workers(selected_model: str = None, ):
active_gen_workers = redis.get(f'active_gen_workers:{selected_model}')
if active_gen_workers is None:
count = 0
else:
count = int(active_gen_workers)
return count
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
# No queue
return gen_time_calc

View File

@ -3,18 +3,20 @@ import traceback
from flask import jsonify, request
from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ...cluster.backend import get_a_cluster_backend
from ...cluster.cluster_config import cluster_config
@bp.route('/generate', methods=['POST'])
def generate():
@bp.route('/v1/generate', methods=['POST'])
@bp.route('/<model_name>/v1/generate', methods=['POST'])
def generate(model_name=None):
request_valid_json, request_json_body = validate_json(request)
if not request_valid_json or not request_json_body.get('prompt'):
return jsonify({'code': 400, 'msg': 'Invalid JSON'}), 400
else:
handler = OobaRequestHandler(request)
handler = OobaRequestHandler(request, model_name)
try:
return handler.handle_request()
except Exception:

View File

@ -2,74 +2,32 @@ import time
from datetime import datetime
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, test_backend
from llm_server.cluster.backend import get_a_cluster_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.model_choices import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
if active_gen_workers < concurrent_gens:
return 0
elif active_gen_workers >= concurrent_gens:
# Calculate how long it will take to complete the currently running gens and the queued requests.
# If the proompters in the queue are equal to the number of workers, just use the calculated generation time.
# Otherwise, use how many requests we can process concurrently times the calculated generation time. Then, round
# that number up to the nearest base gen_time_calc (ie. if gen_time_calc is 8 and the calculated number is 11.6, we will get 18). Finally,
# Add gen_time_calc to the time to account for the currently running generations.
# This assumes that all active workers will finish at the same time, which is unlikely.
# Regardless, this is the most accurate estimate we can get without tracking worker elapsed times.
proompters_in_queue_wait_time = gen_time_calc if (proompters_in_queue / concurrent_gens) <= 1 \
else round_up_base(((proompters_in_queue / concurrent_gens) * gen_time_calc), base=gen_time_calc)
return proompters_in_queue_wait_time + gen_time_calc if active_gen_workers > 0 else 0
elif proompters_in_queue == 0 and active_gen_workers == 0:
# No queue, no workers
return 0
else:
# No queue
return gen_time_calc
from llm_server.helpers import deep_sort
from llm_server.routes.stats import get_total_proompts, server_start_time
def generate_stats(regen: bool = False):
if not regen:
c = redis.get('proxy_stats', dtype=dict)
c = redis.getp('proxy_stats')
if c:
return c
default_backend_url = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend_url)
if not default_backend_info.get('mode'):
# TODO: remove
print('DAEMON NOT FINISHED STARTING')
return
base_client_api = redis.get('base_client_api', dtype=str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0)
online = test_backend(default_backend_url, default_backend_info['mode'])
if online:
running_model, err = get_running_model(default_backend_url, default_backend_info['mode'])
cluster_config.set_backend_value(default_backend_url, 'running_model', running_model)
else:
running_model = None
active_gen_workers = get_active_gen_workers()
proompters_in_queue = len(priority_queue)
# This is so wildly inaccurate it's disabled.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
# TODO: make this for the currently selected backend
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
output = {
'default': {
'model': running_model,
'backend': default_backend_info['hash'],
'model': default_backend_info['model'],
'backend': default_backend_url,
},
'stats': {
'proompters': {
@ -78,21 +36,14 @@ def generate_stats(regen: bool = False):
},
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(average_generation_elapsed_sec),
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
},
'online': online,
'endpoints': {
'blocking': f'https://{base_client_api}',
'streaming': f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
},
'queue': {
'processing': active_gen_workers,
'queued': proompters_in_queue,
'estimated_wait_sec': int(estimated_wait_sec),
},
'timestamp': int(time.time()),
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
@ -106,42 +57,30 @@ def generate_stats(regen: bool = False):
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
}
# TODO: have get_model_choices() return all the info so we don't have to loop over the backends ourself
if opts.show_backends:
for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']:
continue
# TODO: have this fetch the data from VLLM which will display GPU utalization
# if opts.netdata_root:
# netdata_stats = {}
# power_states = get_power_states()
# for gpu, power_state in power_states.items():
# netdata_stats[gpu] = {
# 'power_state': power_state,
# # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
# }
# else:
# netdata_stats = {}
netdata_stats = {}
# TODO: use value returned by VLLM backend here
# backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None
backend_uptime = -1
backend_uptime = int((datetime.now() - datetime.fromtimestamp(backend_info['startup_time'])).total_seconds()) if opts.show_uptime else None
output['backend_info'][backend_info['hash']] = {
'uptime': backend_uptime,
# 'context_size': opts.context_size,
'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'),
'max_tokens': backend_info['model_config']['max_position_embeddings'],
'model': backend_info['model'],
'mode': backend_info['mode'],
'nvidia': netdata_stats
'nvidia': backend_info['nvidia'],
}
else:
output['backend_info'] = {}
output['default'] = get_model_choices(regen=True)[1]
result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then.
if base_client_api:
redis.set_dict('proxy_stats', result) # Cache with no expiry
redis.setp('proxy_stats', result)
return result

View File

@ -10,13 +10,9 @@ from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends
from ...cluster.cluster_config import cluster_config
@bp.route('/model', methods=['GET'])
@bp.route('/<model_name>/model', methods=['GET'])
@bp.route('/v1/model', methods=['GET'])
@bp.route('/<model_name>/v1/model', methods=['GET'])
def get_model(model_name=None):
if not model_name:
b = get_a_cluster_backend()
model_name = cluster_config.get_backend(b)['running_model']
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -25,6 +21,9 @@ def get_model(model_name=None):
if cached_response:
return cached_response
if not model_name:
model_name = cluster_config.get_backend(get_a_cluster_backend()).get('model')
if not is_valid_model(model_name):
response = jsonify({
'code': 400,
@ -32,7 +31,6 @@ def get_model(model_name=None):
}), 400
else:
num_backends = len(get_backends_from_model(model_name))
response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'model_backend_count': num_backends,
@ -47,7 +45,8 @@ def get_model(model_name=None):
@requires_auth
def get_backend():
online, offline = get_backends()
result = []
result = {}
for i in online + offline:
result.append(cluster_config.get_backend(i))
info = cluster_config.get_backend(i)
result[info['hash']] = info
return jsonify(result), 200

View File

@ -1,7 +1,7 @@
import threading
import time
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
@ -9,12 +9,16 @@ from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get()
need_to_wait()
(request_json_body, client_ip, token, parameters, backend_url), event_id, selected_model = priority_queue.get()
if not selected_model:
selected_model = cluster_config.get_backend(backend_url)['model']
# This wait time is "invisible", meaning the worker may as
# well be still waiting to get an item from the queue.
need_to_wait(backend_url)
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers()
incr_active_workers(selected_model, backend_url)
if not request_json_body:
# This was a dummy request from the websocket handler.
@ -22,12 +26,12 @@ def worker():
continue
try:
success, response, error_msg = generator(request_json_body, cluster_backend)
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers()
decr_active_workers(selected_model, backend_url)
def start_workers(num_workers: int):
@ -40,11 +44,12 @@ def start_workers(num_workers: int):
print(f'Started {i} inference workers.')
def need_to_wait():
def need_to_wait(backend_url: str):
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', 0, dtype=int)
active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int)
concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
s = time.time()
while active_workers >= opts.concurrent_gens:
while active_workers >= concurrent_gens:
time.sleep(0.01)
e = time.time()
if e - s > 0.5:

View File

@ -5,7 +5,7 @@ from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.info import get_info, get_running_model
def main_background_thread():
@ -14,8 +14,9 @@ def main_background_thread():
for backend_url in online:
backend_info = cluster_config.get_backend(backend_url)
backend_mode = backend_info['mode']
running_model, err = get_running_model(backend_url, backend_mode)
if err:
backend_info = get_info(backend_url, backend_mode)
running_model = backend_info.get('model')
if not running_model:
continue
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
@ -25,21 +26,6 @@ def main_background_thread():
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
if average_generation_elapsed_sec and average_output_tokens:
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
default_backend_url = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend_url)
default_backend_mode = default_backend_info['mode']
default_running_model, err = get_running_model(default_backend_url, default_backend_mode)
if err:
continue
default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_running_model, default_running_model, default_backend_mode)
if default_average_generation_elapsed_sec:
redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec)
if default_average_output_tokens:
redis.set('average_output_tokens', default_average_output_tokens)
if default_average_generation_elapsed_sec and default_average_output_tokens:
redis.set('estimated_avg_tps', default_estimated_avg_tps)
time.sleep(30)

View File

@ -1,6 +1,7 @@
import logging
import time
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
@ -17,9 +18,11 @@ if not logger.handlers:
def console_printer():
time.sleep(3)
while True:
processing = redis.hkeys('processing_ips')
processing = redis.keys('active_gen_workers:http*') # backends always start with http
processing_count = 0
for ip in processing:
processing_count += int(redis.hget('processing_ips', ip))
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)}')
if len(processing):
for k in processing:
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(10)

View File

@ -15,11 +15,11 @@ from llm_server.workers.recenter import recent_prompters_thread
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(1)
time.sleep(5)
def start_background():
start_workers(opts.concurrent_gens)
start_workers(opts.cluster_workers)
t = Thread(target=main_background_thread)
t.daemon = True

0
other/vllm/vllm_api_server.py Executable file → Normal file
View File

View File

@ -1,20 +1,17 @@
flask~=2.3.3
flask_cors
pyyaml~=6.0.1
flask_caching
requests~=2.31.0
tiktoken~=0.5.0
gunicorn
gevent~=23.9.0.post1
async-timeout
flask-sock
uvicorn~=0.23.2
fastapi~=0.103.1
torch~=2.0.1
PyMySQL~=1.1.0
DBUtils~=3.0.3
simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
flask-sock==0.6.0
gunicorn==21.2.0
redis==5.0.1
git+https://github.com/vllm-project/vllm

View File

@ -1,5 +1,3 @@
from llm_server.cluster.cluster_config import cluster_config
try:
import gevent.monkey
@ -14,10 +12,10 @@ from pathlib import Path
import simplejson as json
from flask import Flask, jsonify, render_template, request
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.redis_cycle import load_backend_cycle
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.model_choices import get_model_choices
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config, parse_backends
from llm_server.config.load import load_config
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.pre_fork import server_startup
@ -26,10 +24,7 @@ from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.sock import init_socketio
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
# TODO: implement background thread to test backends via sending test prompts
# TODO: if backend fails request, mark it as down
# TODO: per-backend workers
# TODO: allow setting concurrent gens per-backend
# TODO: set the max tokens to that of the lowest backend
# TODO: implement RRD backend loadbalancer option
@ -42,6 +37,7 @@ from llm_server.sock import init_socketio
# TODO: have VLLM report context size, uptime
# Lower priority
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
# TODO: the estiamted wait time lags behind the stats
# TODO: simulate OpenAI error messages regardless of endpoint
@ -69,12 +65,11 @@ from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.custom_redis import flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
app = Flask(__name__)
init_socketio(app)
app.register_blueprint(bp, url_prefix='/api/v1/')
app.register_blueprint(bp, url_prefix='/api/v2/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
flask_cache.init_app(app)
flask_cache.clear()
@ -94,37 +89,23 @@ if not success:
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
cluster_config.clear()
cluster_config.load(parse_backends(config))
on, off = get_backends()
load_backend_cycle('backend_cycler', on + off)
@app.route('/')
@app.route('/api')
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
# Use the default backend
backend_url = get_a_cluster_backend()
if backend_url:
backend_info = cluster_config.get_backend(backend_url)
stats = generate_stats(backend_url)
else:
backend_info = stats = None
base_client_api = redis.get('base_client_api', dtype=str)
stats = generate_stats()
if not stats['online']:
running_model = estimated_wait_sec = 'offline'
else:
running_model = backend_info['running_model']
model_choices, default_backend_info = get_model_choices()
active_gen_workers = get_active_gen_workers()
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
if default_backend_info['queued'] == 0 and default_backend_info['queued'] >= opts.concurrent_gens:
# There will be a wait if the queue is empty but prompts are processing, but we don't
# know how long.
estimated_wait_sec = f"less than {stats['stats']['average_generation_elapsed_sec']} seconds"
default_estimated_wait_sec = f"less than {default_backend_info['estimated_wait']} seconds"
else:
estimated_wait_sec = f"{stats['queue']['estimated_wait_sec']} seconds"
default_estimated_wait_sec = f"{default_backend_info['estimated_wait']} seconds"
if len(config['analytics_tracking_code']):
analytics_tracking_code = f"<script>\n{config['analytics_tracking_code']}\n</script>"
@ -137,39 +118,35 @@ def home():
info_html = ''
mode_info = ''
using_vllm = False
for k, v in cluster_config.all().items():
if v['mode'] == vllm:
using_vllm = True
break
if using_vllm == 'vllm':
if v['mode'] == 'vllm':
mode_info = vllm_info
base_client_api = redis.get('base_client_api', dtype=str)
break
return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name,
analytics_tracking_code=analytics_tracking_code,
info_html=info_html,
current_model=opts.manual_model_name if opts.manual_model_name else running_model,
client_api=f'https://{base_client_api}',
ws_client_api=f'wss://{base_client_api}/v1/stream' if opts.enable_streaming else None,
estimated_wait=estimated_wait_sec,
default_model=default_backend_info['model'],
default_active_gen_workers=default_backend_info['processing'],
default_proompters_in_queue=default_backend_info['queued'],
current_model=opts.manual_model_name if opts.manual_model_name else None, # else running_model,
client_api=f'https://{base_client_api}/v2',
ws_client_api=f'wss://{base_client_api}/v2/stream' if opts.enable_streaming else 'disabled',
default_estimated_wait=default_estimated_wait_sec,
mode_name=mode_ui_names[opts.mode][0],
api_input_textbox=mode_ui_names[opts.mode][1],
streaming_input_textbox=mode_ui_names[opts.mode][2],
context_size=opts.context_size,
default_context_size=default_backend_info['context_size'],
stats_json=json.dumps(stats, indent=4, ensure_ascii=False),
extra_info=mode_info,
openai_client_api=f'https://{base_client_api}/openai/v1' if opts.enable_openi_compatible_backend else 'disabled',
expose_openai_system_prompt=opts.expose_openai_system_prompt,
enable_streaming=opts.enable_streaming,
model_choices=model_choices
)
# TODO: add authenticated route to get the current backend URL. Add it to /v1/backend
@app.route('/<first>')
@app.route('/<first>/<path:rest>')
def fallback(first=None, rest=None):

View File

@ -65,6 +65,10 @@
.hidden {
display: none;
}
.header-workers {
font-weight: normal;
}
</style>
</head>
@ -76,8 +80,12 @@
<h1 style="text-align: center;margin-top: 0;">{{ llm_middleware_name }}</h1>
<div class="info-box">
<p><strong>Current Model:</strong> <span id="model">{{ current_model }}</span></p>
<p><strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ estimated_wait }}</span></p>
<p><strong>Current Model:</strong> <span id="model">{{ default_model }}</span></p>
<p>
<strong>Estimated Wait Time:</strong> <span id="estimatedWait">{{ default_estimated_wait }}</span><br>
Processing: {{ default_active_gen_workers }}<br>
Queued: {{ default_proompters_in_queue }}
</p>
<br>
<p><strong>Client API URL:</strong> {{ client_api }}</p>
<p><strong>Streaming API URL:</strong> {{ ws_client_api if enable_streaming else 'Disabled' }}</p>
@ -101,7 +109,7 @@
API key</kbd> textbox.
</li>
<li>Click <kbd>Connect</kbd> to test the connection.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ context_size }}.</li>
<li>Open your preset config and set <kbd>Context Size</kbd> to {{ default_context_size }}.</li>
<li>Follow this guide to get set up: <a href="https://rentry.org/freellamas" target="_blank">rentry.org/freellamas</a>
</li>
</ol>
@ -119,9 +127,30 @@
<br>
{% for key, value in model_choices.items() %}
<div class="info-box">
<h3>{{ key }} <span class="header-workers">- {{ value.backend_count }} workers</span></h3>
<p>
<strong>Estimated Wait Time:</strong> {{ value.estimated_wait }}<br>
Processing: {{ value.processing }}<br>
Queued: {{ value.queued }}<br>
</p>
<p>
<strong>Client API URL:</strong> {{ value.client_api }}<br>
<strong>Streaming API URL:</strong> {{ value.ws_client_api }}<br>
<strong>OpenAI-Compatible API URL:</strong> {{ value.openai_client_api }}
</p>
<p><strong>Context Size:</strong> {{ value.context_size }}</p>
<p><strong>Average Generation Time:</strong> {{ value.avg_generation_time | int }} seconds</p>
</div>
<br>
{% endfor %}
<!--
<div class="info-box">
<pre><code class="language-json" style="background-color: white">{{ stats_json|safe }}</code></pre>
</div>
-->
</div>
<div class="footer">
<a href="https://git.evulid.cc/cyberes/local-llm-server" target="_blank">git.evulid.cc/cyberes/local-llm-server</a>

View File

@ -1,39 +0,0 @@
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import time
from threading import Thread
from llm_server.cluster.redis_cycle import load_backend_cycle
from llm_server.cluster.backend import get_backends, get_a_cluster_backend
from llm_server.cluster.worker import cluster_worker
from llm_server.config.load import parse_backends, load_config
from llm_server.cluster.redis_config_cache import RedisClusterStore
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('config')
args = parser.parse_args()
success, config, msg = load_config(args.config)
cluster_config = RedisClusterStore('cluster_config')
cluster_config.clear()
cluster_config.load(parse_backends(config))
on, off = get_backends()
load_backend_cycle('backend_cycler', on + off)
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
while True:
# online, offline = get_backends()
# print(online, offline)
# print(get_a_cluster_backend())
time.sleep(3)