This commit is contained in:
Cyberes 2023-09-29 00:09:44 -06:00
parent e7b57cad7b
commit 624ca74ce5
47 changed files with 506 additions and 390 deletions

View File

@ -1,22 +1,12 @@
import time
from llm_server.custom_redis import redis
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import os
import sys
import time
from pathlib import Path
from llm_server.config.load import load_config
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.workers.threader import start_background
from llm_server.workers.app import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@ -29,7 +19,7 @@ if __name__ == "__main__":
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
success, config, msg = load_config(config_path, script_path)
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)

View File

@ -0,0 +1,71 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.cluster.redis_cycle import redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.llm.info import get_running_model
def test_backend(backend_url: str, mode: str):
running_model, err = get_running_model(backend_url, mode)
if not running_model:
return False
return True
def get_backends():
cluster_config = RedisClusterStore('cluster_config')
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b['online']
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
offline_backends = sorted(
((url, info) for url, info in backends.items() if not info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends], [url for url, info in offline_backends]
def get_a_cluster_backend():
"""
Get a backend from Redis. If there are no online backends, return None.
"""
online, offline = get_backends()
cycled = redis_cycle('backend_cycler')
c = cycled.copy()
for i in range(len(cycled)):
if cycled[i] in offline:
del c[c.index(cycled[i])]
if len(c):
return c[0]
else:
return None
def get_backends_from_model(model_name: str):
cluster_config = RedisClusterStore('cluster_config')
a = cluster_config.all()
matches = []
for k, v in a.items():
if v['online'] and v['running_model'] == model_name:
matches.append(k)
return matches
def purge_backend_from_running_models(backend_url: str):
keys = redis_running_models.keys()
pipeline = redis_running_models.pipeline()
for model in keys:
pipeline.srem(model, backend_url)
pipeline.execute()
def is_valid_model(model_name: str):
return redis_running_models.exists(model_name)
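The selection helpers above can be exercised end to end: `load_backend_cycle` seeds the rotation list, the cluster worker flips each backend's `online` flag, and callers then ask for the next healthy URL. A minimal usage sketch, assuming a local Redis and two placeholder backend URLs that are not part of this commit:

```python
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.redis_cycle import load_backend_cycle

# Hypothetical URLs, for illustration only.
load_backend_cycle('backend_cycler', ['http://10.0.0.1:7000', 'http://10.0.0.2:7000'])

online, offline = get_backends()       # split by the pickled 'online' flag in cluster_config
backend_url = get_a_cluster_backend()  # next online backend in the rotation, or None
if backend_url is None:
    print('No online backends available.')
```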

View File

@ -0,0 +1,3 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
cluster_config = RedisClusterStore('cluster_config')

View File

@ -1,26 +0,0 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.llm.info import get_running_model
def test_backend(backend_url: str):
running_model, err = get_running_model(backend_url)
if not running_model:
return False
return True
def get_best_backends():
cluster_config = RedisClusterStore('cluster_config')
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b['online']
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends]

View File

@ -1,3 +1,4 @@
import hashlib
import pickle
from llm_server.custom_redis import RedisCustom
@ -13,14 +14,17 @@ class RedisClusterStore:
def load(self, config: dict):
for k, v in config.items():
self.set_backend(k, v)
self.add_backend(k, v)
def set_backend(self, name: str, values: dict):
def add_backend(self, name: str, values: dict):
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
self.set_backend_value(name, 'online', False)
h = hashlib.sha256(name.encode('utf-8')).hexdigest()
self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
def set_backend_value(self, key: str, name: str, value):
def set_backend_value(self, backend: str, key: str, value):
self.config_redis.hset(key, name, pickle.dumps(value))
# By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
self.config_redis.hset(backend, key, pickle.dumps(value))
def get_backend(self, name: str):
r = self.config_redis.hgetall(name)
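Backends are stored as Redis hashes whose field values are pickled, so reads come back as native Python objects without per-field casting, and the short `hash` identifier is derived from the backend name. A rough standalone sketch of the same pattern, using illustrative names rather than the project's API:

```python
import hashlib
import pickle

import redis

r = redis.Redis(host='localhost', port=6379, db=0)

def save_backend(name: str, values: dict):
    # Pickling each field lets bools/ints/dicts survive the round trip unchanged.
    r.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
    h = hashlib.sha256(name.encode('utf-8')).hexdigest()
    r.hset(name, 'hash', pickle.dumps(f'{h[:8]}-{h[-8:]}'))

def load_backend(name: str) -> dict:
    return {k.decode('utf-8'): pickle.loads(v) for k, v in r.hgetall(name).items()}
```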

View File

@ -0,0 +1,21 @@
import redis
r = redis.Redis(host='localhost', port=6379, db=9)
def redis_cycle(list_name):
while True:
pipe = r.pipeline()
pipe.lpop(list_name)
popped_element = pipe.execute()[0]
if popped_element is None:
return None
r.rpush(list_name, popped_element)
new_list = r.lrange(list_name, 0, -1)
return [x.decode('utf-8') for x in new_list]
def load_backend_cycle(list_name: str, elements: list):
r.delete(list_name)
for element in elements:
r.rpush(list_name, element)
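`redis_cycle` implements a round robin by popping the head of a Redis list, pushing it back onto the tail, and returning the rotated list; `load_backend_cycle` (re)seeds that list. Expected behaviour, assuming a local Redis and arbitrary demo values:

```python
from llm_server.cluster.redis_cycle import load_backend_cycle, redis_cycle

load_backend_cycle('demo_cycle', ['a', 'b', 'c'])
print(redis_cycle('demo_cycle'))  # ['b', 'c', 'a'] -- 'a' was rotated to the back
print(redis_cycle('demo_cycle'))  # ['c', 'a', 'b']
```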

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_running_models = RedisCustom('running_models')

View File

@ -1,10 +1,10 @@
import time
from datetime import datetime
from threading import Thread
from llm_server.cluster.funcs.backend import test_backend
from llm_server.cluster.backend import purge_backend_from_running_models, test_backend
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
cluster_config = RedisClusterStore('cluster_config')
from llm_server.llm.info import get_running_model
def cluster_worker():
@ -16,10 +16,16 @@ def cluster_worker():
threads.append(thread)
for thread in threads:
thread.join()
time.sleep(10)
def check_backend(n, v):
# Check if backends are online
online = test_backend(v['backend_url'])
# TODO: also have test_backend() get the uptime
online = test_backend(v['backend_url'], v['mode'])
if online:
running_model, err = get_running_model(v['backend_url'], v['mode'])
if not err:
cluster_config.set_backend_value(n, 'running_model', running_model)
purge_backend_from_running_models(n)
redis_running_models.sadd(running_model, n)
cluster_config.set_backend_value(n, 'online', online)
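The worker spawns one thread per configured backend, joins them, and repeats; each check flips the backend's `online` flag and re-files it under its current model. A condensed sketch of that loop, assuming the helpers shown above and a loop body that is only partially visible in this hunk:

```python
import time
from threading import Thread

from llm_server.cluster.backend import purge_backend_from_running_models, test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
from llm_server.llm.info import get_running_model

def check_backend(name: str, info: dict):
    online = test_backend(info['backend_url'], info['mode'])
    if online:
        running_model, err = get_running_model(info['backend_url'], info['mode'])
        if not err:
            cluster_config.set_backend_value(name, 'running_model', running_model)
            purge_backend_from_running_models(name)   # drop stale model -> backend links
            redis_running_models.sadd(running_model, name)
    cluster_config.set_backend_value(name, 'online', online)

def cluster_worker():
    while True:
        threads = [Thread(target=check_backend, args=(n, v)) for n, v in cluster_config.all().items()]
        for t in threads:
            t.start()
        for t in threads:
            t.join()
        time.sleep(10)
```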

View File

@ -32,7 +32,8 @@ config_default_vars = {
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True,
'netdata_root': None
'netdata_root': None,
'show_backends': True,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -52,6 +52,7 @@ def load_config(config_path):
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')

View File

@ -1,13 +1,13 @@
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Union, Optional
from typing import Callable, List, Mapping, Optional, Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@ -35,12 +35,12 @@ class RedisCustom:
def set(self, key, value, ex: Union[ExpiryT, None] = None):
return self.redis.set(self._key(key), value, ex=ex)
def get(self, key, dtype=None, default=None):
def get(self, key, default=None, dtype=None):
"""
# TODO: use pickle
:param key:
import inspect
:param dtype: convert to this type
if inspect.isclass(default):
:return:
raise Exception
"""
d = self.redis.get(self._key(key))
if dtype and d:
try:
@ -153,11 +153,23 @@ class RedisCustom:
keys = []
for key in raw_keys:
p = key.decode('utf-8').split(':')
if len(p) > 2:
if len(p) >= 2:
# Delete prefix
del p[0]
keys.append(':'.join(p))
k = ':'.join(p)
if k != '____':
keys.append(k)
return keys
def pipeline(self, transaction=True, shard_hint=None):
return self.redis.pipeline(transaction, shard_hint)
def exists(self, *names: KeyT):
n = []
for name in names:
n.append(self._key(name))
return self.redis.exists(*n)
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex)
@ -174,7 +186,7 @@ class RedisCustom:
def getp(self, name: str):
r = self.redis.get(name)
if r:
return pickle.load(r)
return pickle.loads(r)
return r
def flush(self):
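Call sites throughout this commit move from `get(key, dtype, default)` to the new `get(key, default=None, dtype=None)` ordering. A simplified, illustrative sketch of how such a prefixed wrapper typically resolves a typed read (this is not the class's full logic; the class name and prefix here are placeholders):

```python
import pickle

import redis

class PrefixedRedis:
    """Minimal illustration of a key-prefixing Redis wrapper."""
    def __init__(self, prefix: str, **kwargs):
        self.prefix = prefix
        self.redis = redis.Redis(**kwargs)

    def _key(self, key: str) -> str:
        return f'{self.prefix}:{key}'

    def get(self, key, default=None, dtype=None):
        d = self.redis.get(self._key(key))
        if d is None:
            return default
        if dtype in (str, int, float):
            return dtype(d.decode('utf-8'))
        return d

    def setp(self, key, value):
        # Pickled values avoid manual casts on read (mirrors getp/pickle.loads above).
        return self.redis.set(self._key(key), pickle.dumps(value))

    def getp(self, key):
        d = self.redis.get(self._key(key))
        return pickle.loads(d) if d is not None else None
```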

View File

@ -1,60 +1,69 @@
import json
import time
import traceback
from threading import Thread
import llm_server
from llm_server import opts
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
from llm_server.custom_redis import redis
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False):
if isinstance(response, dict) and response.get('results'):
def background_task():
response = response['results'][0]['text']
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error
try:
# Try not to shove JSON into the database.
j = json.loads(response)
if isinstance(response, dict) and response.get('results'):
if j.get('results'):
response = response['results'][0]['text']
response = j['results'][0]['text']
try:
except:
j = json.loads(response)
pass
if j.get('results'):
response = j['results'][0]['text']
except:
pass
prompt_tokens = llm_server.llm.get_token_count(prompt)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)
else:
response_tokens = None
# Sometimes we may want to insert null into the DB, but
# usually we want to insert a float.
if gen_time:
gen_time = round(gen_time, 3)
if is_error:
gen_time = None
if not opts.log_prompts:
prompt = None
if not opts.log_prompts and not is_error:
# TODO: test and verify this works as expected
response = None
if token:
increment_token_uses(token)
running_model = redis.get('running_model', str, 'ERROR')
timestamp = int(time.time())
cursor = database.cursor()
try:
cursor.execute("""
INSERT INTO prompts
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
(ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
# TODO: use async/await instead of threads
thread = Thread(target=background_task)
thread.start()
thread.join()
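`log_prompt` now takes the serving backend (`cluster_backend`) and records it in the `backend_url` column. A hypothetical call from a request handler, with every value here being a placeholder supplied only for illustration:

```python
from llm_server.database.database import log_prompt

log_prompt(ip='1.2.3.4', token=None, prompt='Hello', response='Hi there', gen_time=1.234,
           parameters={'max_new_tokens': 64}, headers={}, backend_response_code=200,
           request_url='https://proxy.example/api/v1/generate',
           cluster_backend='http://10.0.0.1:7000', response_tokens=2)
```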
def is_valid_api_key(api_key):

View File

@ -60,7 +60,7 @@ def round_up_base(n, base):
def auto_set_base_client_api(request):
http_host = redis.get('http_host', str)
http_host = redis.get('http_host', dtype=str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
# If the current http_host is not an IP, don't do anything.

View File

@ -1,12 +0,0 @@
import threading
class ThreadSafeInteger:
def __init__(self, value=0):
self.value = value
self._value_lock = threading.Lock()
def increment(self):
with self._value_lock:
self.value += 1
return self.value

View File

@ -3,7 +3,7 @@ from llm_server.custom_redis import redis
def get_token_count(prompt: str):
backend_mode = redis.get('backend_mode', str)
backend_mode = redis.get('backend_mode', dtype=str)
if backend_mode == 'vllm':
return vllm.tokenize(prompt)
elif backend_mode == 'ooba':

View File

@ -1,14 +1,14 @@
from llm_server import opts
def generator(request_json_body):
def generator(request_json_body, cluster_backend):
if opts.mode == 'oobabooga':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm':
from .vllm.generate import generate
r = generate(request_json_body)
r = generate(request_json_body, cluster_backend)
return r
else:
raise Exception

View File

@ -3,19 +3,15 @@ import requests
from llm_server import opts
def get_running_model(backend_url: str):
def get_running_model(backend_url: str, mode: str):
# TODO: remove this once we go to Redis
if mode == 'ooba':
if not backend_url:
backend_url = opts.backend_url
if opts.mode == 'oobabooga':
try:
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['result'], None
except Exception as e:
return False, e
elif opts.mode == 'vllm':
elif mode == 'vllm':
try:
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
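`get_running_model` now takes the backend's mode explicitly instead of consulting the global `opts.mode`. A quick usage sketch with a placeholder URL:

```python
from llm_server.llm.info import get_running_model

model, err = get_running_model('http://10.0.0.1:7000', 'vllm')
if err:
    print('Backend unreachable:', err)
else:
    print('Backend is serving', model)
```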

View File

@ -40,6 +40,6 @@ class LLMBackend:
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = get_token_count(prompt)
if prompt_len > opts.context_size - 10:
model_name = redis.get('running_model', str, 'NO MODEL ERROR')
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
return True, None

View File

@ -34,7 +34,7 @@ def build_openai_response(prompt, response, model=None):
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
@ -57,7 +57,7 @@ def build_openai_response(prompt, response, model=None):
}
}), 200)
stats = redis.get('proxy_stats', dict)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response

View File

@ -49,7 +49,7 @@ def transform_to_text(json_request, api_response):
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
completion_tokens = len(llm_server.llm.get_token_count(text))
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
return {
@ -82,9 +82,9 @@ def transform_prompt_to_text(prompt: list):
return text.strip('\n')
def handle_blocking_request(json_data: dict):
def handle_blocking_request(json_data: dict, cluster_backend):
try:
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
except requests.exceptions.ReadTimeout:
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out'
@ -97,11 +97,11 @@ def handle_blocking_request(json_data: dict):
return True, r, None
def generate(json_data: dict):
def generate(json_data: dict, cluster_backend):
if json_data.get('stream'):
try:
return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
else:
return handle_blocking_request(json_data)
return handle_blocking_request(json_data, cluster_backend)

View File

@ -19,16 +19,8 @@ class VLLMBackend(LLMBackend):
# Failsafe
backend_response = ''
r_url = request.url
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
def background_task():
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
return jsonify({'results': [{'text': backend_response}]}), 200

View File

@ -2,7 +2,6 @@
# TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'ERROR'
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None
@ -38,3 +37,4 @@ openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True

View File

@ -7,7 +7,7 @@ from llm_server.routes.v1.generate_stats import generate_stats
def server_startup(s):
if not redis.get('daemon_started', bool):
if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)

View File

@ -1,10 +1,14 @@
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, level: str = 'info'):
def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'):
http_host = redis.get('http_host', str)
cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash']
http_host = redis.get('http_host', dtype=str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
{msg}
BACKEND HASH: {cluster_backend_hash}
```"""

View File

@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this
response_msg = error_msg
else:
response_msg = format_sillytavern_err(error_msg, error_type)
response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend)
return jsonify({
'results': [{'text': response_msg}]

View File

@ -5,11 +5,12 @@ import traceback
from flask import Response, jsonify, request
from . import openai_bp
from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
@ -48,10 +49,11 @@ def openai_chat_completions():
'stream': True,
}
try:
response = generator(msg_to_backend)
cluster_backend = get_a_cluster_backend()
response = generator(msg_to_backend, cluster_backend)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
@ -94,7 +96,7 @@ def openai_chat_completions():
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)

View File

@ -29,7 +29,7 @@ def openai_completions():
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
@ -51,7 +51,7 @@
}
}), 200)
stats = redis.get('proxy_stats', dict)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response

View File

@ -7,6 +7,7 @@ from . import openai_bp
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
@ -22,7 +23,7 @@ def openai_list_models():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
oai = fetch_openai_models()
r = []
if opts.openai_expose_our_model:

View File

@ -93,6 +93,6 @@ def incr_active_workers():
def decr_active_workers():
redis.decr('active_gen_workers')
new_count = redis.get('active_gen_workers', int, 0)
new_count = redis.get('active_gen_workers', 0, dtype=int)
if new_count < 0:
redis.set('active_gen_workers', 0)

View File

@ -5,13 +5,15 @@ import flask
from flask import Response, request
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import log_prompt
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.custom_redis import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
@ -35,7 +37,9 @@ class RequestHandler:
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
self.backend = get_backend()
self.cluster_backend = get_a_cluster_backend()
self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
self.backend = get_backend_handler(self.cluster_backend)
self.parameters = None
self.used = False
redis.zadd('recent_prompters', {self.client_ip: time.time()})
@ -119,7 +123,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
return False, backend_response
return True, (None, 0)
@ -131,7 +135,7 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid:
return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority)
else:
event = None
@ -160,7 +164,7 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -180,7 +184,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -214,10 +218,10 @@ class RequestHandler:
raise NotImplementedError
def get_backend():
def get_backend_handler(mode):
if opts.mode == 'oobabooga':
if mode == 'oobabooga':
return OobaboogaBackend()
elif opts.mode == 'vllm':
elif mode == 'vllm':
return VLLMBackend()
else:
raise Exception
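Each handler now binds to a concrete cluster backend at construction time. A minimal sketch of that resolution path, using the get_backend_handler() defined above and selecting the handler by the backend's mode field (the wrapper name resolve_backend is illustrative, not part of the commit):

```python
from llm_server.cluster.backend import get_a_cluster_backend
from llm_server.cluster.cluster_config import cluster_config

def resolve_backend():
    cluster_backend = get_a_cluster_backend()           # next online URL from the rotation
    info = cluster_config.get_backend(cluster_backend)  # unpickled hash fields: mode, online, hash, ...
    handler = get_backend_handler(info['mode'])         # OobaboogaBackend() or VLLMBackend()
    return cluster_backend, info, handler
```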

View File

@ -2,32 +2,9 @@ from datetime import datetime
from llm_server.custom_redis import redis
# proompters_5_min = 0
# concurrent_semaphore = Semaphore(concurrent_gens)
server_start_time = datetime.now()
# TODO: do I need this?
# def elapsed_times_cleanup():
# global wait_in_queue_elapsed
# while True:
# current_time = time.time()
# with wait_in_queue_elapsed_lock:
# global wait_in_queue_elapsed
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
# time.sleep(1)
def calculate_avg_gen_time():
# Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time')
if average_generation_time is None:
return 0
else:
return float(average_generation_time)
def get_total_proompts():
count = redis.get('proompts')
if count is None:

View File

@ -2,11 +2,12 @@ import time
from datetime import datetime
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
@ -33,52 +34,43 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
return gen_time_calc
# TODO: have routes/__init__.py point to the latest API version generate_stats()
def generate_stats(regen: bool = False):
if not regen:
c = redis.get('proxy_stats', dict)
c = redis.get('proxy_stats', dtype=dict)
if c:
return c
model_name, error = get_running_model() # will return False when the fetch fails
default_backend_url = get_a_cluster_backend()
if isinstance(model_name, bool):
default_backend_info = cluster_config.get_backend(default_backend_url)
online = False
if not default_backend_info.get('mode'):
else:
# TODO: remove
online = True
print('DAEMON NOT FINISHED STARTING')
redis.set('running_model', model_name)
return
base_client_api = redis.get('base_client_api', dtype=str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
online = test_backend(default_backend_url, default_backend_info['mode'])
# if len(t) == 0:
if online:
# estimated_wait = 0
running_model, err = get_running_model(default_backend_url, default_backend_info['mode'])
# else:
cluster_config.set_backend_value(default_backend_url, 'running_model', running_model)
# waits = [elapsed for end, elapsed in t]
else:
# estimated_wait = int(sum(waits) / len(waits))
running_model = None
active_gen_workers = get_active_gen_workers()
proompters_in_queue = len(priority_queue)
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
# This is so wildly inaccurate it's disabled.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
# TODO: make this for the currently selected backend
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
if opts.netdata_root:
netdata_stats = {}
power_states = get_power_states()
for gpu, power_state in power_states.items():
netdata_stats[gpu] = {
'power_state': power_state,
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
}
else:
netdata_stats = {}
base_client_api = redis.get('base_client_api', str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
output = {
'default': {
'model': running_model,
'backend': default_backend_info['hash'],
},
'stats': {
'proompters': {
'5_min': proompters_5_min,
@ -86,9 +78,10 @@ def generate_stats(regen: bool = False):
},
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(average_generation_time),
'average_generation_elapsed_sec': int(average_generation_elapsed_sec),
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
},
'online': online,
'endpoints': {
@ -103,10 +96,7 @@ def generate_stats(regen: bool = False):
'timestamp': int(time.time()),
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'concurrent': opts.concurrent_gens,
'model': opts.manual_model_name if opts.manual_model_name else model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
},
'keys': {
@ -114,8 +104,41 @@ def generate_stats(regen: bool = False):
'anthropicKeys': '',
},
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
'nvidia': netdata_stats
}
if opts.show_backends:
for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']:
continue
# TODO: have this fetch the data from VLLM which will display GPU utalization
# if opts.netdata_root:
# netdata_stats = {}
# power_states = get_power_states()
# for gpu, power_state in power_states.items():
# netdata_stats[gpu] = {
# 'power_state': power_state,
# # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
# }
# else:
# netdata_stats = {}
netdata_stats = {}
# TODO: use value returned by VLLM backend here
# backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None
backend_uptime = -1
output['backend_info'][backend_info['hash']] = {
'uptime': backend_uptime,
# 'context_size': opts.context_size,
'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'),
'mode': backend_info['mode'],
'nvidia': netdata_stats
}
else:
output['backend_info'] = {}
result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then.
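Put together, the cached stats payload now looks roughly like the following. Only the keys come from the code above; every value is a made-up placeholder, and fields not visible in this hunk are omitted:

```python
example_stats = {
    'default': {'model': 'MythoMax-L2-13B', 'backend': '0123abcd-89ef4567'},
    'stats': {
        'proompters': {'5_min': 3},  # remaining proompter fields omitted
        'proompts_total': 12345,
        'uptime': 86400,
        'average_generation_elapsed_sec': 12,
        'tokens_generated': 9876543,
        'num_backends': 2,
    },
    'online': True,
    'endpoints': {},  # endpoint URLs omitted in this sketch
    'timestamp': 1695945600,
    'config': {'gatekeeper': 'none', 'concurrent': 3, 'simultaneous_requests_per_ip': 3},
    'keys': {'openaiKeys': '', 'anthropicKeys': ''},
    'backend_info': {
        '0123abcd-89ef4567': {'uptime': -1, 'model': 'MythoMax-L2-13B', 'mode': 'vllm', 'nvidia': {}},
    },
}
```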

View File

@ -1,5 +1,4 @@
import json
import threading
import time
import traceback
from typing import Union
@ -10,10 +9,11 @@ from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock
from ...sock import sock
# TODO: have workers process streaming requests
@ -35,19 +35,13 @@ def stream(ws):
log_in_bg(quitting_err_msg, is_error=True)
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
generated_tokens = tokenize(generated_text_bg)
def background_task_exception():
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error)
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
if not opts.enable_streaming:
return 'Streaming is disabled', 401
cluster_backend = None
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
@ -90,14 +84,15 @@ def stream(ws):
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request)
cluster_backend = get_a_cluster_backend()
response = generator(llm_request, cluster_backend)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
@ -142,7 +137,7 @@ def stream(ws):
ws.close()
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
return
message_num += 1
@ -181,5 +176,5 @@ def stream(ws):
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
ws.close() # this is important if we encountered and error and exited early.

View File

@ -2,22 +2,21 @@ import time
from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp
from ..auth import requires_auth
from llm_server.custom_redis import flask_cache
from ... import opts
from ...llm.info import get_running_model
from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config
# @bp.route('/info', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_info():
# # requests.get()
# return 'yes'
@bp.route('/model', methods=['GET'])
def get_model():
@bp.route('/<model_name>/model', methods=['GET'])
def get_model(model_name=None):
if not model_name:
b = get_a_cluster_backend()
model_name = cluster_config.get_backend(b)['running_model']
# We will manage caching ourself since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -26,16 +25,17 @@ def get_model():
if cached_response:
return cached_response
model_name, error = get_running_model()
if not is_valid_model(model_name):
if not model_name:
response = jsonify({
'code': 502,
'code': 400,
'msg': 'failed to reach backend',
'msg': 'Model does not exist.',
'type': error.__class__.__name__
}), 400
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
num_backends = len(get_backends_from_model(model_name))
response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'model_backend_count': num_backends,
'timestamp': int(time.time())
}), 200
flask_cache.set(cache_key, response, timeout=60)
@ -43,7 +43,11 @@ def get_model():
return response
@bp.route('/backend', methods=['GET'])
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
online, offline = get_backends()
result = []
for i in online + offline:
result.append(cluster_config.get_backend(i))
return jsonify(result), 200
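The reworked routes can be exercised with plain HTTP; the base URL below is a placeholder for an actual deployment, and the route prefix depends on how the blueprint is mounted:

```python
import requests

base = 'https://proxy.example/api/v1'  # hypothetical instance URL

# Default backend's model, or a specific model via f'{base}/<model_name>/model'.
print(requests.get(f'{base}/model').json())

# Every configured backend, online and offline (requires auth; see @requires_auth).
print(requests.get(f'{base}/backends').json())
```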

View File

@ -1,35 +0,0 @@
from threading import Thread
from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')

View File

@ -2,15 +2,15 @@ import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
(request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get()
need_to_wait()
increment_ip_count(client_ip, 'processing_ips')
@ -22,7 +22,7 @@ def worker():
continue
try:
success, response, error_msg = generator(request_json_body)
success, response, error_msg = generator(request_json_body, cluster_backend)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
@ -42,7 +42,7 @@ def start_workers(num_workers: int):
def need_to_wait():
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', int, 0)
active_workers = redis.get('active_gen_workers', 0, dtype=int)
s = time.time()
while active_workers >= opts.concurrent_gens:
time.sleep(0.01)
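Queue items are now five-element tuples that carry the chosen backend, so the worker can hand it straight to generator(). A hedged sketch of the producer side; the wrapper function name is illustrative and the arguments come from the request context in the real handlers:

```python
from llm_server.routes.queue import priority_queue

def enqueue_request(request_json_body, client_ip, token, parameters, cluster_backend, token_priority):
    """Sketch of how handlers enqueue work for the blocking workers."""
    event = priority_queue.put((request_json_body, client_ip, token, parameters, cluster_backend), token_priority)
    if not event:
        return None  # the queue refused the request (per-IP simultaneous limit)
    return event
```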

View File

@ -1,55 +0,0 @@
import time

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.custom_redis import redis


def main_background_thread():
    redis.set('average_generation_elapsed_sec', 0)
    redis.set('estimated_avg_tps', 0)
    redis.set('average_output_tokens', 0)
    redis.set('backend_online', 0)
    redis.set_dict('backend_info', {})

    while True:
        # TODO: unify this
        if opts.mode == 'oobabooga':
            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
        elif opts.mode == 'vllm':
            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
        else:
            raise Exception

        # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
        # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
        average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
        if average_generation_elapsed_sec:  # returns None on exception
            redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
        # overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
        # print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')

        average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
        if average_generation_elapsed_sec:
            redis.set('average_output_tokens', average_output_tokens)
        # overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
        # print(f'Weighted: {average_output_tokens}, overall: {overall}')

        estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
        redis.set('estimated_avg_tps', estimated_avg_tps)
        time.sleep(60)


@@ -0,0 +1,56 @@
import time

from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model


def main_background_thread():
    while True:
        online, offline = get_backends()
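        # Recompute each online backend's performance stats and cache them in the cluster config.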
        for backend_url in online:
            backend_info = cluster_config.get_backend(backend_url)
            backend_mode = backend_info['mode']
            running_model, err = get_running_model(backend_url, backend_mode)
            if err:
                continue

            average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
            if average_generation_elapsed_sec:  # returns None on exception
                cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
            if average_output_tokens:
                cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
            if average_generation_elapsed_sec and average_output_tokens:
                cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
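
        # Also refresh the global Redis keys using whichever backend is currently the default.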
        default_backend_url = get_a_cluster_backend()
        default_backend_info = cluster_config.get_backend(default_backend_url)
        default_backend_mode = default_backend_info['mode']
        default_running_model, err = get_running_model(default_backend_url, default_backend_mode)
        if err:
            continue

        default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_backend_url, default_running_model, default_backend_mode)
        if default_average_generation_elapsed_sec:
            redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec)
        if default_average_output_tokens:
            redis.set('average_output_tokens', default_average_output_tokens)
        if default_average_generation_elapsed_sec and default_average_output_tokens:
            redis.set('estimated_avg_tps', default_estimated_avg_tps)

        time.sleep(30)

def calc_stats_for_backend(backend_url, running_model, backend_mode):
    # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
    # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
    average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
                                                                       running_model, backend_mode, backend_url, exclude_zeros=True,
                                                                       include_system_tokens=opts.include_system_tokens_in_stats) or 0
    average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
                                                              running_model, backend_mode, backend_url, exclude_zeros=True,
                                                              include_system_tokens=opts.include_system_tokens_in_stats) or 0
    estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0  # Avoid division by zero
    return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps


@@ -0,0 +1,50 @@
import time
from threading import Thread

from llm_server import opts
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers
from llm_server.workers.printer import console_printer
from llm_server.workers.recenter import recent_prompters_thread


def cache_stats():
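    # Regenerate the aggregated stats every second so they are always fresh when requested.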
    while True:
        generate_stats(regen=True)
        time.sleep(1)


def start_background():
    start_workers(opts.concurrent_gens)

    t = Thread(target=main_background_thread)
    t.daemon = True
    t.start()
    print('Started the main background thread.')

    start_moderation_workers(opts.openai_moderation_workers)

    t = Thread(target=cache_stats)
    t.daemon = True
    t.start()
    print('Started the stats cacher.')

    t = Thread(target=recent_prompters_thread)
    t.daemon = True
    t.start()
    print('Started the recent proompters thread.')

    t = Thread(target=console_printer)
    t.daemon = True
    t.start()
    print('Started the console printer.')
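
    # Clear any stale model-to-backend mappings before the cluster worker starts tracking backends.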
    redis_running_models.flush()
    t = Thread(target=cluster_worker)
    t.daemon = True
    t.start()
    print('Started the cluster worker.')


@@ -1,9 +0,0 @@
import time

from llm_server.routes.v1.generate_stats import generate_stats


def cache_stats():
    while True:
        generate_stats(regen=True)
        time.sleep(5)


@@ -1,3 +1,8 @@
+"""
+This file is used to run certain tasks when the HTTP server starts.
+It's located here so it doesn't get imported with daemon.py
+"""
+
 try:
     import gevent.monkey


@@ -1,4 +1,4 @@
-from llm_server.config.config import mode_ui_names
+from llm_server.cluster.cluster_config import cluster_config

 try:
     import gevent.monkey
@@ -7,8 +7,6 @@ try:
 except ImportError:
     pass

-from llm_server.pre_fork import server_startup
-from llm_server.config.load import load_config, parse_backends
 import os
 import sys
 from pathlib import Path
@@ -16,14 +14,17 @@ from pathlib import Path
 import simplejson as json
 from flask import Flask, jsonify, render_template, request

-import llm_server
+from llm_server.cluster.backend import get_a_cluster_backend, get_backends
+from llm_server.cluster.redis_cycle import load_backend_cycle
+from llm_server.config.config import mode_ui_names
+from llm_server.config.load import load_config, parse_backends
 from llm_server.database.conn import database
 from llm_server.database.create import create_db
-from llm_server.llm import get_token_count
+from llm_server.pre_fork import server_startup
 from llm_server.routes.openai import openai_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
-from llm_server.stream import init_socketio
+from llm_server.sock import init_socketio

 # TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
 # TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@@ -37,6 +38,8 @@ from llm_server.stream import init_socketio
 # TODO: use coloredlogs
 # TODO: need to update opts. for workers
 # TODO: add a healthcheck to VLLM
+# TODO: allow choosing the model by the URL path
+# TODO: have VLLM report context size, uptime

 # Lower priority
 # TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
@@ -64,7 +67,7 @@ import config
 from llm_server import opts
 from llm_server.helpers import auto_set_base_client_api
 from llm_server.llm.vllm.info import vllm_info
-from llm_server.custom_redis import RedisCustom, flask_cache
+from llm_server.custom_redis import flask_cache
 from llm_server.llm import redis
 from llm_server.routes.stats import get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
@@ -83,20 +86,18 @@ if config_path_environ:
 else:
     config_path = Path(script_path, 'config', 'config.yml')

-success, config, msg = load_config(config_path, script_path)
+success, config, msg = load_config(config_path)
 if not success:
     print('Failed to load config:', msg)
     sys.exit(1)

 database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
 create_db()

-llm_server.llm.redis = RedisCustom('local_llm')
-create_db()
-x = parse_backends(config)
-print(x)
+cluster_config.clear()
+cluster_config.load(parse_backends(config))
+on, off = get_backends()
+load_backend_cycle('backend_cycler', on + off)

-# print(app.url_map)
 @app.route('/')
@@ -104,12 +105,18 @@ print(x)
 @app.route('/api/openai')
 @flask_cache.cached(timeout=10)
 def home():
-    stats = generate_stats()
+    # Use the default backend
+    backend_url = get_a_cluster_backend()
+    if backend_url:
+        backend_info = cluster_config.get_backend(backend_url)
+        stats = generate_stats(backend_url)
+    else:
+        backend_info = stats = None
     if not stats['online']:
         running_model = estimated_wait_sec = 'offline'
     else:
-        running_model = redis.get('running_model', str, 'ERROR')
+        running_model = backend_info['running_model']

         active_gen_workers = get_active_gen_workers()
         if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
@@ -130,10 +137,16 @@ def home():
     info_html = ''
     mode_info = ''
-    if opts.mode == 'vllm':
+    using_vllm = False
+    for k, v in cluster_config.all().items():
+        if v['mode'] == 'vllm':
+            using_vllm = True
+            break
+    if using_vllm:
         mode_info = vllm_info

-    base_client_api = redis.get('base_client_api', str)
+    base_client_api = redis.get('base_client_api', dtype=str)

     return render_template('home.html',
                            llm_middleware_name=opts.llm_middleware_name,


@@ -7,23 +7,33 @@ except ImportError:
 import time
 from threading import Thread

-from llm_server.cluster.funcs.backend import get_best_backends
+from llm_server.cluster.redis_cycle import load_backend_cycle
+from llm_server.cluster.backend import get_backends, get_a_cluster_backend
+from llm_server.cluster.redis_config_cache import RedisClusterStore
 from llm_server.cluster.worker import cluster_worker
 from llm_server.config.load import parse_backends, load_config
-from llm_server.cluster.redis_config_cache import RedisClusterStore

-success, config, msg = load_config('./config/config.yml').resolve().absolute()
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('config')
+args = parser.parse_args()
+
+success, config, msg = load_config(args.config)

 cluster_config = RedisClusterStore('cluster_config')
 cluster_config.clear()
 cluster_config.load(parse_backends(config))
+on, off = get_backends()
+load_backend_cycle('backend_cycler', on + off)

 t = Thread(target=cluster_worker)
 t.daemon = True
 t.start()

 while True:
-    x = get_best_backends()
-    print(x)
+    # online, offline = get_backends()
+    # print(online, offline)
+    # print(get_a_cluster_backend())
     time.sleep(3)