This commit is contained in:
Cyberes 2023-09-29 00:09:44 -06:00
parent e7b57cad7b
commit 624ca74ce5
47 changed files with 506 additions and 390 deletions

View File

@ -1,22 +1,12 @@
import time
from llm_server.custom_redis import redis
try:
import gevent.monkey
gevent.monkey.patch_all()
except ImportError:
pass
import os
import sys
import time
from pathlib import Path
from llm_server.config.load import load_config
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.workers.app import start_background
from llm_server.workers.threader import start_background
script_path = os.path.dirname(os.path.realpath(__file__))
config_path_environ = os.getenv("CONFIG_PATH")
@ -29,7 +19,7 @@ if __name__ == "__main__":
flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.')
success, config, msg = load_config(config_path, script_path)
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)

View File

@ -0,0 +1,71 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.cluster.redis_cycle import redis_cycle
from llm_server.cluster.stores import redis_running_models
from llm_server.llm.info import get_running_model
def test_backend(backend_url: str, mode: str):
running_model, err = get_running_model(backend_url, mode)
if not running_model:
return False
return True
def get_backends():
cluster_config = RedisClusterStore('cluster_config')
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b['online']
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
offline_backends = sorted(
((url, info) for url, info in backends.items() if not info['online']),
key=lambda kv: -kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends], [url for url, info in offline_backends]
def get_a_cluster_backend():
"""
Get a backend from Redis. If there are no online backends, return None.
"""
online, offline = get_backends()
cycled = redis_cycle('backend_cycler')
online_cycled = [url for url in cycled if url not in offline]
return online_cycled[0] if online_cycled else None
def get_backends_from_model(model_name: str):
cluster_config = RedisClusterStore('cluster_config')
a = cluster_config.all()
matches = []
for k, v in a.items():
if v['online'] and v['running_model'] == model_name:
matches.append(k)
return matches
def purge_backend_from_running_models(backend_url: str):
keys = redis_running_models.keys()
pipeline = redis_running_models.pipeline()
for model in keys:
pipeline.srem(model, backend_url)
pipeline.execute()
def is_valid_model(model_name: str):
return redis_running_models.exists(model_name)
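For orientation, a minimal sketch of how these helpers get wired together elsewhere in this commit (server.py seeds the cycle the same way); the backend URL and settings below are invented for illustration:
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.redis_cycle import load_backend_cycle
# Hypothetical parsed config: backend URL -> settings.
cluster_config.clear()
cluster_config.load({'http://10.0.0.1:7000': {'mode': 'vllm', 'priority': 100}})
# Seed the round-robin list with every known backend, then ask for one.
online, offline = get_backends()
load_backend_cycle('backend_cycler', online + offline)
print(get_a_cluster_backend())  # None until the cluster worker marks a backend online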

View File

@ -0,0 +1,3 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
cluster_config = RedisClusterStore('cluster_config')

View File

@ -1,26 +0,0 @@
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.llm.info import get_running_model
def test_backend(backend_url: str):
running_model, err = get_running_model(backend_url)
if not running_model:
return False
return True
def get_best_backends():
cluster_config = RedisClusterStore('cluster_config')
backends = cluster_config.all()
result = {}
for k, v in backends.items():
b = cluster_config.get_backend(k)
status = b['online']
priority = b['priority']
result[k] = {'status': status, 'priority': priority}
online_backends = sorted(
((url, info) for url, info in backends.items() if info['online']),
key=lambda kv: kv[1]['priority'],
reverse=True
)
return [url for url, info in online_backends]

View File

@ -1,3 +1,4 @@
import hashlib
import pickle
from llm_server.custom_redis import RedisCustom
@ -13,14 +14,17 @@ class RedisClusterStore:
def load(self, config: dict):
for k, v in config.items():
self.set_backend(k, v)
self.add_backend(k, v)
def set_backend(self, name: str, values: dict):
def add_backend(self, name: str, values: dict):
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
self.set_backend_value(name, 'online', False)
h = hashlib.sha256(name.encode('utf-8')).hexdigest()
self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
def set_backend_value(self, key: str, name: str, value):
self.config_redis.hset(key, name, pickle.dumps(value))
def set_backend_value(self, backend: str, key: str, value):
# By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
self.config_redis.hset(backend, key, pickle.dumps(value))
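A tiny sketch of the round-trip that comment describes, assuming get_backend() unpickles the hash fields on read; the backend URL is hypothetical:
store = RedisClusterStore('cluster_config')
store.set_backend_value('http://10.0.0.1:7000', 'priority', 100)
info = store.get_backend('http://10.0.0.1:7000')
assert info['priority'] == 100  # comes back as an int, no manual casting needed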
def get_backend(self, name: str):
r = self.config_redis.hgetall(name)

View File

@ -0,0 +1,21 @@
import redis
r = redis.Redis(host='localhost', port=6379, db=9)
def redis_cycle(list_name):
while True:
pipe = r.pipeline()
pipe.lpop(list_name)
popped_element = pipe.execute()[0]
if popped_element is None:
# The list is empty; return an empty list so callers can iterate safely.
return []
r.rpush(list_name, popped_element)
new_list = r.lrange(list_name, 0, -1)
return [x.decode('utf-8') for x in new_list]
def load_backend_cycle(list_name: str, elements: list):
r.delete(list_name)
for element in elements:
r.rpush(list_name, element)
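A quick sketch of the rotation these two helpers provide; the list name matches the 'backend_cycler' cycle used elsewhere in this commit, and the URLs are made up:
from llm_server.cluster.redis_cycle import load_backend_cycle, redis_cycle
# Seed the cycle with two hypothetical backends.
load_backend_cycle('backend_cycler', ['http://10.0.0.1:7000', 'http://10.0.0.2:7000'])
# Each call pops the head, pushes it back onto the tail, and returns the new order,
# so successive calls walk the backends round-robin.
print(redis_cycle('backend_cycler'))  # ['http://10.0.0.2:7000', 'http://10.0.0.1:7000']
print(redis_cycle('backend_cycler'))  # ['http://10.0.0.1:7000', 'http://10.0.0.2:7000']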

View File

@ -0,0 +1,3 @@
from llm_server.custom_redis import RedisCustom
redis_running_models = RedisCustom('running_models')

View File

@ -1,10 +1,10 @@
import time
from datetime import datetime
from threading import Thread
from llm_server.cluster.funcs.backend import test_backend
from llm_server.cluster.redis_config_cache import RedisClusterStore
cluster_config = RedisClusterStore('cluster_config')
from llm_server.cluster.backend import purge_backend_from_running_models, test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
from llm_server.llm.info import get_running_model
def cluster_worker():
@ -16,10 +16,16 @@ def cluster_worker():
threads.append(thread)
for thread in threads:
thread.join()
time.sleep(10)
def check_backend(n, v):
# Check if backends are online
online = test_backend(v['backend_url'])
# TODO: also have test_backend() get the uptime
online = test_backend(v['backend_url'], v['mode'])
if online:
running_model, err = get_running_model(v['backend_url'], v['mode'])
if not err:
cluster_config.set_backend_value(n, 'running_model', running_model)
purge_backend_from_running_models(n)
redis_running_models.sadd(running_model, n)
cluster_config.set_backend_value(n, 'online', online)

View File

@ -32,7 +32,8 @@ config_default_vars = {
'openai_org_name': 'OpenAI',
'openai_silent_trim': False,
'openai_moderation_enabled': True,
'netdata_root': None
'netdata_root': None,
'show_backends': True,
}
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -52,6 +52,7 @@ def load_config(config_path):
opts.openai_org_name = config['openai_org_name']
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_expose_our_model to true, you must also set your OpenAI key in openai_api_key.')

View File

@ -1,13 +1,13 @@
import pickle
import sys
import traceback
from typing import Callable, List, Mapping, Union, Optional
from typing import Callable, List, Mapping, Optional, Union
import redis as redis_pkg
import simplejson as json
from flask_caching import Cache
from redis import Redis
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@ -35,12 +35,12 @@ class RedisCustom:
def set(self, key, value, ex: Union[ExpiryT, None] = None):
return self.redis.set(self._key(key), value, ex=ex)
def get(self, key, dtype=None, default=None):
"""
:param key:
:param dtype: convert to this type
:return:
"""
def get(self, key, default=None, dtype=None):
# TODO: use pickle
import inspect
if inspect.isclass(default):
# Guard against callers still using the old get(key, dtype) positional signature.
raise Exception('default must be a value, not a type -- pass the type via the dtype argument instead')
d = self.redis.get(self._key(key))
if dtype and d:
try:
@ -153,11 +153,23 @@ class RedisCustom:
keys = []
for key in raw_keys:
p = key.decode('utf-8').split(':')
if len(p) > 2:
if len(p) >= 2:
# Delete prefix
del p[0]
keys.append(':'.join(p))
k = ':'.join(p)
if k != '____':
keys.append(k)
return keys
def pipeline(self, transaction=True, shard_hint=None):
return self.redis.pipeline(transaction, shard_hint)
def exists(self, *names: KeyT):
n = []
for name in names:
n.append(self._key(name))
return self.redis.exists(*n)
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex)
@ -174,7 +186,7 @@ class RedisCustom:
def getp(self, name: str):
r = self.redis.get(name)
if r:
return pickle.load(r)
return pickle.loads(r)
return r
def flush(self):

View File

@ -1,60 +1,69 @@
import json
import time
import traceback
from threading import Thread
import llm_server
from llm_server import opts
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.llm.vllm import tokenize
from llm_server.custom_redis import redis
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
try:
j = json.loads(response)
if j.get('results'):
response = j['results'][0]['text']
except:
pass
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False):
def background_task():
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error
# Try not to shove JSON into the database.
if isinstance(response, dict) and response.get('results'):
response = response['results'][0]['text']
try:
j = json.loads(response)
if j.get('results'):
response = j['results'][0]['text']
except:
pass
prompt_tokens = llm_server.llm.get_token_count(prompt)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)
else:
response_tokens = None
prompt_tokens = llm_server.llm.get_token_count(prompt)
if not is_error:
if not response_tokens:
response_tokens = llm_server.llm.get_token_count(response)
else:
response_tokens = None
# Sometimes we may want to insert null into the DB, but
# usually we want to insert a float.
if gen_time:
gen_time = round(gen_time, 3)
if is_error:
gen_time = None
# Sometimes we may want to insert null into the DB, but
# usually we want to insert a float.
if gen_time:
gen_time = round(gen_time, 3)
if is_error:
gen_time = None
if not opts.log_prompts:
prompt = None
if not opts.log_prompts:
prompt = None
if not opts.log_prompts and not is_error:
# TODO: test and verify this works as expected
response = None
if not opts.log_prompts and not is_error:
# TODO: test and verify this works as expected
response = None
if token:
increment_token_uses(token)
if token:
increment_token_uses(token)
running_model = redis.get('running_model', str, 'ERROR')
timestamp = int(time.time())
cursor = database.cursor()
try:
cursor.execute("""
INSERT INTO prompts
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
running_model = redis.get('running_model', 'ERROR', dtype=str)
timestamp = int(time.time())
cursor = database.cursor()
try:
cursor.execute("""
INSERT INTO prompts
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
(ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
finally:
cursor.close()
# TODO: use async/await instead of threads
thread = Thread(target=background_task)
thread.start()
thread.join()
def is_valid_api_key(api_key):

View File

@ -60,7 +60,7 @@ def round_up_base(n, base):
def auto_set_base_client_api(request):
http_host = redis.get('http_host', str)
http_host = redis.get('http_host', dtype=str)
host = request.headers.get("Host")
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
# If the current http_host is not an IP, don't do anything.

View File

@ -1,12 +0,0 @@
import threading
class ThreadSafeInteger:
def __init__(self, value=0):
self.value = value
self._value_lock = threading.Lock()
def increment(self):
with self._value_lock:
self.value += 1
return self.value

View File

@ -3,7 +3,7 @@ from llm_server.custom_redis import redis
def get_token_count(prompt: str):
backend_mode = redis.get('backend_mode', str)
backend_mode = redis.get('backend_mode', dtype=str)
if backend_mode == 'vllm':
return vllm.tokenize(prompt)
elif backend_mode == 'ooba':

View File

@ -1,14 +1,14 @@
from llm_server import opts
def generator(request_json_body):
def generator(request_json_body, cluster_backend):
if opts.mode == 'oobabooga':
# from .oobabooga.generate import generate
# return generate(request_json_body)
raise NotImplementedError
elif opts.mode == 'vllm':
from .vllm.generate import generate
r = generate(request_json_body)
r = generate(request_json_body, cluster_backend)
return r
else:
raise Exception

View File

@ -3,19 +3,15 @@ import requests
from llm_server import opts
def get_running_model(backend_url: str):
# TODO: remove this once we go to Redis
if not backend_url:
backend_url = opts.backend_url
if opts.mode == 'oobabooga':
def get_running_model(backend_url: str, mode: str):
if mode == 'ooba':
try:
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()
return r_json['result'], None
except Exception as e:
return False, e
elif opts.mode == 'vllm':
elif mode == 'vllm':
try:
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
r_json = backend_response.json()

View File

@ -40,6 +40,6 @@ class LLMBackend:
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
prompt_len = get_token_count(prompt)
if prompt_len > opts.context_size - 10:
model_name = redis.get('running_model', str, 'NO MODEL ERROR')
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
return True, None

View File

@ -34,7 +34,7 @@ def build_openai_response(prompt, response, model=None):
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)
response_tokens = llm_server.llm.get_token_count(response)
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"chatcmpl-{generate_oai_string(30)}",
@ -57,7 +57,7 @@ def build_openai_response(prompt, response, model=None):
}
}), 200)
stats = redis.get('proxy_stats', dict)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response

View File

@ -49,7 +49,7 @@ def transform_to_text(json_request, api_response):
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
completion_tokens = len(llm_server.llm.get_token_count(text))
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
return {
@ -82,9 +82,9 @@ def transform_prompt_to_text(prompt: list):
return text.strip('\n')
def handle_blocking_request(json_data: dict):
def handle_blocking_request(json_data: dict, cluster_backend):
try:
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
except requests.exceptions.ReadTimeout:
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
return False, None, 'Request to backend timed out'
@ -97,11 +97,11 @@ def handle_blocking_request(json_data: dict):
return True, r, None
def generate(json_data: dict):
def generate(json_data: dict, cluster_backend):
if json_data.get('stream'):
try:
return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
except Exception as e:
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
else:
return handle_blocking_request(json_data)
return handle_blocking_request(json_data, cluster_backend)

View File

@ -19,16 +19,8 @@ class VLLMBackend(LLMBackend):
# Failsafe
backend_response = ''
r_url = request.url
def background_task():
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)
thread.start()
thread.join()
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
return jsonify({'results': [{'text': backend_response}]}), 200

View File

@ -2,7 +2,6 @@
# TODO: rewrite the config system so I don't have to add every single config default here
running_model = 'ERROR'
concurrent_gens = 3
mode = 'oobabooga'
backend_url = None
@ -38,3 +37,4 @@ openai_org_name = 'OpenAI'
openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True

View File

@ -7,7 +7,7 @@ from llm_server.routes.v1.generate_stats import generate_stats
def server_startup(s):
if not redis.get('daemon_started', bool):
if not redis.get('daemon_started', dtype=bool):
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
sys.exit(1)

View File

@ -1,10 +1,14 @@
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
def format_sillytavern_err(msg: str, level: str = 'info'):
http_host = redis.get('http_host', str)
def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'):
cluster_backend_hash = cluster_config.get_backend(backend_url)['hash']
http_host = redis.get('http_host', dtype=str)
return f"""```
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
-> {level.upper()} <-
{msg}
BACKEND HASH: {cluster_backend_hash}
```"""
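For illustration, with a hypothetical host, backend hash, and message, the template above renders something like:
```
=== MESSAGE FROM LLM MIDDLEWARE AT llm.example.com ===
-> ERROR <-
Request to backend timed out
BACKEND HASH: 1a2b3c4d-5e6f7a8b
```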

View File

@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg)
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True)
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
# TODO: how to format this
response_msg = error_msg
else:
response_msg = format_sillytavern_err(error_msg, error_type)
response_msg = format_sillytavern_err(error_msg, self.cluster_backend, error_type)
return jsonify({
'results': [{'text': response_msg}]

View File

@ -5,11 +5,12 @@ import traceback
from flask import Response, jsonify, request
from . import openai_bp
from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
@ -48,10 +49,11 @@ def openai_chat_completions():
'stream': True,
}
try:
response = generator(msg_to_backend)
cluster_backend = get_a_cluster_backend()
response = generator(msg_to_backend, cluster_backend)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
@ -94,7 +96,7 @@ def openai_chat_completions():
def background_task():
generated_tokens = tokenize(generated_text)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task)

View File

@ -29,7 +29,7 @@ def openai_completions():
# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'])
response_tokens = get_token_count(output)
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
response = make_response(jsonify({
"id": f"cmpl-{generate_oai_string(30)}",
@ -51,7 +51,7 @@ def openai_completions():
}
}), 200)
stats = redis.get('proxy_stats', dict)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response

View File

@ -7,6 +7,7 @@ from . import openai_bp
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
from ..stats import server_start_time
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...helpers import jsonify_pretty
from ...llm.info import get_running_model
@ -22,7 +23,7 @@ def openai_list_models():
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
else:
running_model = redis.get('running_model', str, 'ERROR')
running_model = redis.get('running_model', 'ERROR', dtype=str)
oai = fetch_openai_models()
r = []
if opts.openai_expose_our_model:

View File

@ -93,6 +93,6 @@ def incr_active_workers():
def decr_active_workers():
redis.decr('active_gen_workers')
new_count = redis.get('active_gen_workers', int, 0)
new_count = redis.get('active_gen_workers', 0, dtype=int)
if new_count < 0:
redis.set('active_gen_workers', 0)

View File

@ -5,13 +5,15 @@ import flask
from flask import Response, request
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.conn import database
from llm_server.database.database import log_prompt
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
from llm_server.llm.vllm.vllm_backend import VLLMBackend
from llm_server.routes.auth import parse_token
from llm_server.custom_redis import redis
from llm_server.routes.helpers.http import require_api_key, validate_json
from llm_server.routes.queue import priority_queue
@ -35,7 +37,9 @@ class RequestHandler:
self.client_ip = self.get_client_ip()
self.token = self.get_auth_token()
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
self.backend = get_backend()
self.cluster_backend = get_a_cluster_backend()
self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
self.backend = get_backend_handler(self.cluster_backend)
self.parameters = None
self.used = False
redis.zadd('recent_prompters', {self.client_ip: time.time()})
@ -119,7 +123,7 @@ class RequestHandler:
backend_response = self.handle_error(combined_error_message, 'Validation Error')
if do_log:
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
return False, backend_response
return True, (None, 0)
@ -131,7 +135,7 @@ class RequestHandler:
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
if not request_valid:
return (False, None, None, 0), invalid_response
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority)
else:
event = None
@ -160,7 +164,7 @@ class RequestHandler:
else:
error_msg = error_msg.strip('.') + '.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -180,7 +184,7 @@ class RequestHandler:
if return_json_err:
error_msg = 'The backend did not return valid JSON.'
backend_response = self.handle_error(error_msg)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
return (False, None, None, 0), backend_response
# ===============================================
@ -214,10 +218,10 @@ class RequestHandler:
raise NotImplementedError
def get_backend():
if opts.mode == 'oobabooga':
def get_backend_handler(mode):
if mode == 'oobabooga':
return OobaboogaBackend()
elif opts.mode == 'vllm':
elif mode == 'vllm':
return VLLMBackend()
else:
raise Exception

View File

@ -2,32 +2,9 @@ from datetime import datetime
from llm_server.custom_redis import redis
# proompters_5_min = 0
# concurrent_semaphore = Semaphore(concurrent_gens)
server_start_time = datetime.now()
# TODO: do I need this?
# def elapsed_times_cleanup():
# global wait_in_queue_elapsed
# while True:
# current_time = time.time()
# with wait_in_queue_elapsed_lock:
# global wait_in_queue_elapsed
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
# time.sleep(1)
def calculate_avg_gen_time():
# Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time')
if average_generation_time is None:
return 0
else:
return float(average_generation_time)
def get_total_proompts():
count = redis.get('proompts')
if count is None:

View File

@ -2,11 +2,12 @@ import time
from datetime import datetime
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import get_distinct_ips_24h, sum_column
from llm_server.helpers import deep_sort, round_up_base
from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states
from llm_server.custom_redis import redis
from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
@ -33,52 +34,43 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
return gen_time_calc
# TODO: have routes/__init__.py point to the latest API version generate_stats()
def generate_stats(regen: bool = False):
if not regen:
c = redis.get('proxy_stats', dict)
c = redis.get('proxy_stats', dtype=dict)
if c:
return c
model_name, error = get_running_model() # will return False when the fetch fails
if isinstance(model_name, bool):
online = False
else:
online = True
redis.set('running_model', model_name)
default_backend_url = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend_url)
if not default_backend_info.get('mode'):
# TODO: remove
print('DAEMON NOT FINISHED STARTING')
return
base_client_api = redis.get('base_client_api', dtype=str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0)
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
# if len(t) == 0:
# estimated_wait = 0
# else:
# waits = [elapsed for end, elapsed in t]
# estimated_wait = int(sum(waits) / len(waits))
online = test_backend(default_backend_url, default_backend_info['mode'])
if online:
running_model, err = get_running_model(default_backend_url, default_backend_info['mode'])
cluster_config.set_backend_value(default_backend_url, 'running_model', running_model)
else:
running_model = None
active_gen_workers = get_active_gen_workers()
proompters_in_queue = len(priority_queue)
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
# This is so wildly inaccurate it's disabled.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
if opts.netdata_root:
netdata_stats = {}
power_states = get_power_states()
for gpu, power_state in power_states.items():
netdata_stats[gpu] = {
'power_state': power_state,
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
}
else:
netdata_stats = {}
base_client_api = redis.get('base_client_api', str)
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
# TODO: make this for the currently selected backend
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
output = {
'default': {
'model': running_model,
'backend': default_backend_info['hash'],
},
'stats': {
'proompters': {
'5_min': proompters_5_min,
@ -86,9 +78,10 @@ def generate_stats(regen: bool = False):
},
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(average_generation_time),
'average_generation_elapsed_sec': int(average_generation_elapsed_sec),
# 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
},
'online': online,
'endpoints': {
@ -103,10 +96,7 @@ def generate_stats(regen: bool = False):
'timestamp': int(time.time()),
'config': {
'gatekeeper': 'none' if opts.auth_required is False else 'token',
'context_size': opts.context_size,
'concurrent': opts.concurrent_gens,
'model': opts.manual_model_name if opts.manual_model_name else model_name,
'mode': opts.mode,
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
},
'keys': {
@ -114,8 +104,41 @@ def generate_stats(regen: bool = False):
'anthropicKeys': '',
},
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
'nvidia': netdata_stats
}
if opts.show_backends:
for backend_url, v in cluster_config.all().items():
backend_info = cluster_config.get_backend(backend_url)
if not backend_info['online']:
continue
# TODO: have this fetch the data from VLLM which will display GPU utilization
# if opts.netdata_root:
# netdata_stats = {}
# power_states = get_power_states()
# for gpu, power_state in power_states.items():
# netdata_stats[gpu] = {
# 'power_state': power_state,
# # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
# }
# else:
# netdata_stats = {}
netdata_stats = {}
# TODO: use value returned by VLLM backend here
# backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None
backend_uptime = -1
output['backend_info'][backend_info['hash']] = {
'uptime': backend_uptime,
# 'context_size': opts.context_size,
'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'),
'mode': backend_info['mode'],
'nvidia': netdata_stats
}
else:
output['backend_info'] = {}
result = deep_sort(output)
# It may take a bit to get the base client API, so don't cache until then.

View File

@ -1,5 +1,4 @@
import json
import threading
import time
import traceback
from typing import Union
@ -10,10 +9,11 @@ from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ... import opts
from ...cluster.backend import get_a_cluster_backend
from ...database.database import log_prompt
from ...llm.generator import generator
from ...llm.vllm import tokenize
from ...stream import sock
from ...sock import sock
# TODO: have workers process streaming requests
@ -35,19 +35,13 @@ def stream(ws):
log_in_bg(quitting_err_msg, is_error=True)
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
def background_task_exception():
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
# TODO: use async/await instead of threads
thread = threading.Thread(target=background_task_exception)
thread.start()
thread.join()
generated_tokens = tokenize(generated_text_bg)
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error)
if not opts.enable_streaming:
return 'Streaming is disabled', 401
cluster_backend = None
r_headers = dict(request.headers)
r_url = request.url
message_num = 0
@ -90,14 +84,15 @@ def stream(ws):
}
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
try:
response = generator(llm_request)
cluster_backend = get_a_cluster_backend()
response = generator(llm_request, cluster_backend)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
@ -142,7 +137,7 @@ def stream(ws):
ws.close()
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
return
message_num += 1
@ -181,5 +176,5 @@ def stream(ws):
# The client closed the stream.
end_time = time.time()
elapsed_time = end_time - start_time
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
ws.close() # this is important if we encountered and error and exited early.

View File

@ -2,22 +2,21 @@ import time
from flask import jsonify, request
from llm_server.custom_redis import flask_cache
from . import bp
from ..auth import requires_auth
from llm_server.custom_redis import flask_cache
from ... import opts
from ...llm.info import get_running_model
# @bp.route('/info', methods=['GET'])
# # @cache.cached(timeout=3600, query_string=True)
# def get_info():
# # requests.get()
# return 'yes'
from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
from ...cluster.cluster_config import cluster_config
@bp.route('/model', methods=['GET'])
def get_model():
@bp.route('/<model_name>/model', methods=['GET'])
def get_model(model_name=None):
if not model_name:
b = get_a_cluster_backend()
model_name = cluster_config.get_backend(b)['running_model']
# We will manage caching ourselves since we don't want to cache
# when the backend is down. Also, Cloudflare won't cache 500 errors.
cache_key = 'model_cache::' + request.url
@ -26,16 +25,17 @@ def get_model():
if cached_response:
return cached_response
model_name, error = get_running_model()
if not model_name:
if not is_valid_model(model_name):
response = jsonify({
'code': 502,
'msg': 'failed to reach backend',
'type': error.__class__.__name__
}), 500 # return 500 so Cloudflare doesn't intercept us
'code': 400,
'msg': 'Model does not exist.',
}), 400
else:
num_backends = len(get_backends_from_model(model_name))
response = jsonify({
'result': opts.manual_model_name if opts.manual_model_name else model_name,
'model_backend_count': num_backends,
'timestamp': int(time.time())
}), 200
flask_cache.set(cache_key, response, timeout=60)
@ -43,7 +43,11 @@ def get_model():
return response
@bp.route('/backend', methods=['GET'])
@bp.route('/backends', methods=['GET'])
@requires_auth
def get_backend():
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
online, offline = get_backends()
result = []
for i in online + offline:
result.append(cluster_config.get_backend(i))
return jsonify(result), 200

View File

@ -1,35 +0,0 @@
from threading import Thread
from .blocking import start_workers
from .main import main_background_thread
from .moderator import start_moderation_workers
from .printer import console_printer
from .recent import recent_prompters_thread
from .threads import cache_stats
from .. import opts
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')

View File

@ -2,15 +2,15 @@ import threading
import time
from llm_server import opts
from llm_server.llm.generator import generator
from llm_server.custom_redis import redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
def worker():
while True:
need_to_wait()
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
(request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get()
need_to_wait()
increment_ip_count(client_ip, 'processing_ips')
@ -22,7 +22,7 @@ def worker():
continue
try:
success, response, error_msg = generator(request_json_body)
success, response, error_msg = generator(request_json_body, cluster_backend)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
@ -42,7 +42,7 @@ def start_workers(num_workers: int):
def need_to_wait():
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get('active_gen_workers', int, 0)
active_workers = redis.get('active_gen_workers', 0, dtype=int)
s = time.time()
while active_workers >= opts.concurrent_gens:
time.sleep(0.01)

View File

@ -1,55 +0,0 @@
import time
from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.custom_redis import redis
def main_background_thread():
redis.set('average_generation_elapsed_sec', 0)
redis.set('estimated_avg_tps', 0)
redis.set('average_output_tokens', 0)
redis.set('backend_online', 0)
redis.set_dict('backend_info', {})
while True:
# TODO: unify this
if opts.mode == 'oobabooga':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
elif opts.mode == 'vllm':
running_model, err = get_running_model()
if err:
print(err)
redis.set('backend_online', 0)
else:
redis.set('running_model', running_model)
redis.set('backend_online', 1)
else:
raise Exception
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec: # returns None on exception
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
if average_generation_elapsed_sec:
redis.set('average_output_tokens', average_output_tokens)
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
redis.set('estimated_avg_tps', estimated_avg_tps)
time.sleep(60)

View File

@ -0,0 +1,56 @@
import time
from llm_server import opts
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
def main_background_thread():
while True:
online, offline = get_backends()
for backend_url in online:
backend_info = cluster_config.get_backend(backend_url)
backend_mode = backend_info['mode']
running_model, err = get_running_model(backend_url, backend_mode)
if err:
continue
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
if average_generation_elapsed_sec: # returns None on exception
cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
if average_output_tokens:
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
if average_generation_elapsed_sec and average_output_tokens:
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
default_backend_url = get_a_cluster_backend()
default_backend_info = cluster_config.get_backend(default_backend_url)
default_backend_mode = default_backend_info['mode']
default_running_model, err = get_running_model(default_backend_url, default_backend_mode)
if err:
continue
default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_backend_url, default_running_model, default_backend_mode)
if default_average_generation_elapsed_sec:
redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec)
if default_average_output_tokens:
redis.set('average_output_tokens', default_average_output_tokens)
if default_average_generation_elapsed_sec and default_average_output_tokens:
redis.set('estimated_avg_tps', default_estimated_avg_tps)
time.sleep(30)
def calc_stats_for_backend(backend_url, running_model, backend_mode):
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
running_model, backend_mode, backend_url, exclude_zeros=True,
include_system_tokens=opts.include_system_tokens_in_stats) or 0
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps
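A quick worked example of the estimate above (numbers invented): with an average of 180 response tokens and an average generation time of 12 seconds, estimated_avg_tps = round(180 / 12, 2) = 15.0 tokens per second; when average_generation_elapsed_sec is 0 (no data yet), the estimate stays at 0 to avoid dividing by zero.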

View File

@ -0,0 +1,50 @@
import time
from threading import Thread
from llm_server import opts
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.mainer import main_background_thread
from llm_server.workers.moderator import start_moderation_workers
from llm_server.workers.printer import console_printer
from llm_server.workers.recenter import recent_prompters_thread
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(1)
def start_background():
start_workers(opts.concurrent_gens)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
start_moderation_workers(opts.openai_moderation_workers)
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')
redis_running_models.flush()
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
print('Started the cluster worker.')

View File

@ -1,9 +0,0 @@
import time
from llm_server.routes.v1.generate_stats import generate_stats
def cache_stats():
while True:
generate_stats(regen=True)
time.sleep(5)

View File

@ -1,3 +1,8 @@
"""
This file is used to run certain tasks when the HTTP server starts.
It's located here so it doesn't get imported with daemon.py
"""
try:
import gevent.monkey

View File

@ -1,4 +1,4 @@
from llm_server.config.config import mode_ui_names
from llm_server.cluster.cluster_config import cluster_config
try:
import gevent.monkey
@ -7,8 +7,6 @@ try:
except ImportError:
pass
from llm_server.pre_fork import server_startup
from llm_server.config.load import load_config, parse_backends
import os
import sys
from pathlib import Path
@ -16,14 +14,17 @@ from pathlib import Path
import simplejson as json
from flask import Flask, jsonify, render_template, request
import llm_server
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
from llm_server.cluster.redis_cycle import load_backend_cycle
from llm_server.config.config import mode_ui_names
from llm_server.config.load import load_config, parse_backends
from llm_server.database.conn import database
from llm_server.database.create import create_db
from llm_server.llm import get_token_count
from llm_server.pre_fork import server_startup
from llm_server.routes.openai import openai_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.stream import init_socketio
from llm_server.sock import init_socketio
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
@ -37,6 +38,8 @@ from llm_server.stream import init_socketio
# TODO: use coloredlogs
# TODO: need to update opts. for workers
# TODO: add a healthcheck to VLLM
# TODO: allow choosing the model by the URL path
# TODO: have VLLM report context size, uptime
# Lower priority
# TODO: estimated wait time needs to account for all concurrent_gens workers being busy even when the queue is shorter than concurrent_gens
@ -64,7 +67,7 @@ import config
from llm_server import opts
from llm_server.helpers import auto_set_base_client_api
from llm_server.llm.vllm.info import vllm_info
from llm_server.custom_redis import RedisCustom, flask_cache
from llm_server.custom_redis import flask_cache
from llm_server.llm import redis
from llm_server.routes.stats import get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats
@ -83,20 +86,18 @@ if config_path_environ:
else:
config_path = Path(script_path, 'config', 'config.yml')
success, config, msg = load_config(config_path, script_path)
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
sys.exit(1)
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
create_db()
llm_server.llm.redis = RedisCustom('local_llm')
create_db()
x = parse_backends(config)
print(x)
# print(app.url_map)
cluster_config.clear()
cluster_config.load(parse_backends(config))
on, off = get_backends()
load_backend_cycle('backend_cycler', on + off)
@app.route('/')
@ -104,12 +105,18 @@ print(x)
@app.route('/api/openai')
@flask_cache.cached(timeout=10)
def home():
stats = generate_stats()
# Use the default backend
backend_url = get_a_cluster_backend()
if backend_url:
backend_info = cluster_config.get_backend(backend_url)
stats = generate_stats(backend_url)
else:
backend_info = stats = None
if not stats or not stats['online']:
running_model = estimated_wait_sec = 'offline'
else:
running_model = redis.get('running_model', str, 'ERROR')
running_model = backend_info['running_model']
active_gen_workers = get_active_gen_workers()
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
@ -130,10 +137,16 @@ def home():
info_html = ''
mode_info = ''
if opts.mode == 'vllm':
using_vllm = False
for k, v in cluster_config.all().items():
if v['mode'] == 'vllm':
using_vllm = True
break
if using_vllm:
mode_info = vllm_info
base_client_api = redis.get('base_client_api', str)
base_client_api = redis.get('base_client_api', dtype=str)
return render_template('home.html',
llm_middleware_name=opts.llm_middleware_name,

View File

@ -7,23 +7,33 @@ except ImportError:
import time
from threading import Thread
from llm_server.cluster.redis_cycle import load_backend_cycle
from llm_server.cluster.funcs.backend import get_best_backends
from llm_server.cluster.redis_config_cache import RedisClusterStore
from llm_server.cluster.backend import get_backends, get_a_cluster_backend
from llm_server.cluster.worker import cluster_worker
from llm_server.config.load import parse_backends, load_config
from llm_server.cluster.redis_config_cache import RedisClusterStore
success, config, msg = load_config('./config/config.yml').resolve().absolute()
import argparse
parser = argparse.ArgumentParser()
parser.add_argument('config')
args = parser.parse_args()
success, config, msg = load_config(args.config)
cluster_config = RedisClusterStore('cluster_config')
cluster_config.clear()
cluster_config.load(parse_backends(config))
on, off = get_backends()
load_backend_cycle('backend_cycler', on + off)
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
while True:
x = get_best_backends()
print(x)
# online, offline = get_backends()
# print(online, offline)
# print(get_a_cluster_backend())
time.sleep(3)