mvp
This commit is contained in:
parent
e7b57cad7b
commit
624ca74ce5
18
daemon.py
18
daemon.py
|
@ -1,22 +1,12 @@
|
|||
import time
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
||||
gevent.monkey.patch_all()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from llm_server.config.load import load_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.create import create_db
|
||||
|
||||
from llm_server.workers.app import start_background
|
||||
from llm_server.workers.threader import start_background
|
||||
|
||||
script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
config_path_environ = os.getenv("CONFIG_PATH")
|
||||
|
@ -29,7 +19,7 @@ if __name__ == "__main__":
|
|||
flushed_keys = redis.flush()
|
||||
print('Flushed', len(flushed_keys), 'keys from Redis.')
|
||||
|
||||
success, config, msg = load_config(config_path, script_path)
|
||||
success, config, msg = load_config(config_path)
|
||||
if not success:
|
||||
print('Failed to load config:', msg)
|
||||
sys.exit(1)
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
from llm_server.cluster.redis_cycle import redis_cycle
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.llm.info import get_running_model
|
||||
|
||||
|
||||
def test_backend(backend_url: str, mode: str):
|
||||
running_model, err = get_running_model(backend_url, mode)
|
||||
if not running_model:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_backends():
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
backends = cluster_config.all()
|
||||
result = {}
|
||||
for k, v in backends.items():
|
||||
b = cluster_config.get_backend(k)
|
||||
status = b['online']
|
||||
priority = b['priority']
|
||||
result[k] = {'status': status, 'priority': priority}
|
||||
online_backends = sorted(
|
||||
((url, info) for url, info in backends.items() if info['online']),
|
||||
key=lambda kv: -kv[1]['priority'],
|
||||
reverse=True
|
||||
)
|
||||
offline_backends = sorted(
|
||||
((url, info) for url, info in backends.items() if not info['online']),
|
||||
key=lambda kv: -kv[1]['priority'],
|
||||
reverse=True
|
||||
)
|
||||
return [url for url, info in online_backends], [url for url, info in offline_backends]
|
||||
|
||||
|
||||
def get_a_cluster_backend():
|
||||
"""
|
||||
Get a backend from Redis. If there are no online backends, return None.
|
||||
"""
|
||||
online, offline = get_backends()
|
||||
cycled = redis_cycle('backend_cycler')
|
||||
c = cycled.copy()
|
||||
for i in range(len(cycled)):
|
||||
if cycled[i] in offline:
|
||||
del c[c.index(cycled[i])]
|
||||
if len(c):
|
||||
return c[0]
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_backends_from_model(model_name: str):
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
a = cluster_config.all()
|
||||
matches = []
|
||||
for k, v in a.items():
|
||||
if v['online'] and v['running_model'] == model_name:
|
||||
matches.append(k)
|
||||
return matches
|
||||
|
||||
|
||||
def purge_backend_from_running_models(backend_url: str):
|
||||
keys = redis_running_models.keys()
|
||||
pipeline = redis_running_models.pipeline()
|
||||
for model in keys:
|
||||
pipeline.srem(model, backend_url)
|
||||
pipeline.execute()
|
||||
|
||||
|
||||
def is_valid_model(model_name: str):
|
||||
return redis_running_models.exists(model_name)
|
|
@ -0,0 +1,3 @@
|
|||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
|
@ -1,26 +0,0 @@
|
|||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
from llm_server.llm.info import get_running_model
|
||||
|
||||
|
||||
def test_backend(backend_url: str):
|
||||
running_model, err = get_running_model(backend_url)
|
||||
if not running_model:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_best_backends():
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
backends = cluster_config.all()
|
||||
result = {}
|
||||
for k, v in backends.items():
|
||||
b = cluster_config.get_backend(k)
|
||||
status = b['online']
|
||||
priority = b['priority']
|
||||
result[k] = {'status': status, 'priority': priority}
|
||||
online_backends = sorted(
|
||||
((url, info) for url, info in backends.items() if info['online']),
|
||||
key=lambda kv: kv[1]['priority'],
|
||||
reverse=True
|
||||
)
|
||||
return [url for url, info in online_backends]
|
|
@ -1,3 +1,4 @@
|
|||
import hashlib
|
||||
import pickle
|
||||
|
||||
from llm_server.custom_redis import RedisCustom
|
||||
|
@ -13,14 +14,17 @@ class RedisClusterStore:
|
|||
|
||||
def load(self, config: dict):
|
||||
for k, v in config.items():
|
||||
self.set_backend(k, v)
|
||||
self.add_backend(k, v)
|
||||
|
||||
def set_backend(self, name: str, values: dict):
|
||||
def add_backend(self, name: str, values: dict):
|
||||
self.config_redis.hset(name, mapping={k: pickle.dumps(v) for k, v in values.items()})
|
||||
self.set_backend_value(name, 'online', False)
|
||||
h = hashlib.sha256(name.encode('utf-8')).hexdigest()
|
||||
self.set_backend_value(name, 'hash', f'{h[:8]}-{h[-8:]}')
|
||||
|
||||
def set_backend_value(self, key: str, name: str, value):
|
||||
self.config_redis.hset(key, name, pickle.dumps(value))
|
||||
def set_backend_value(self, backend: str, key: str, value):
|
||||
# By storing the value as a pickle we don't have to cast anything when getting the value from Redis.
|
||||
self.config_redis.hset(backend, key, pickle.dumps(value))
|
||||
|
||||
def get_backend(self, name: str):
|
||||
r = self.config_redis.hgetall(name)
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
import redis
|
||||
|
||||
r = redis.Redis(host='localhost', port=6379, db=9)
|
||||
|
||||
|
||||
def redis_cycle(list_name):
|
||||
while True:
|
||||
pipe = r.pipeline()
|
||||
pipe.lpop(list_name)
|
||||
popped_element = pipe.execute()[0]
|
||||
if popped_element is None:
|
||||
return None
|
||||
r.rpush(list_name, popped_element)
|
||||
new_list = r.lrange(list_name, 0, -1)
|
||||
return [x.decode('utf-8') for x in new_list]
|
||||
|
||||
|
||||
def load_backend_cycle(list_name: str, elements: list):
|
||||
r.delete(list_name)
|
||||
for element in elements:
|
||||
r.rpush(list_name, element)
|
|
@ -0,0 +1,3 @@
|
|||
from llm_server.custom_redis import RedisCustom
|
||||
|
||||
redis_running_models = RedisCustom('running_models')
|
|
@ -1,10 +1,10 @@
|
|||
import time
|
||||
from datetime import datetime
|
||||
from threading import Thread
|
||||
|
||||
from llm_server.cluster.funcs.backend import test_backend
|
||||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
from llm_server.cluster.backend import purge_backend_from_running_models, test_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.llm.info import get_running_model
|
||||
|
||||
|
||||
def cluster_worker():
|
||||
|
@ -16,10 +16,16 @@ def cluster_worker():
|
|||
threads.append(thread)
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
time.sleep(10)
|
||||
|
||||
|
||||
def check_backend(n, v):
|
||||
# Check if backends are online
|
||||
online = test_backend(v['backend_url'])
|
||||
# TODO: also have test_backend() get the uptime
|
||||
online = test_backend(v['backend_url'], v['mode'])
|
||||
if online:
|
||||
running_model, err = get_running_model(v['backend_url'], v['mode'])
|
||||
if not err:
|
||||
cluster_config.set_backend_value(n, 'running_model', running_model)
|
||||
purge_backend_from_running_models(n)
|
||||
redis_running_models.sadd(running_model, n)
|
||||
cluster_config.set_backend_value(n, 'online', online)
|
||||
|
|
|
@ -32,7 +32,8 @@ config_default_vars = {
|
|||
'openai_org_name': 'OpenAI',
|
||||
'openai_silent_trim': False,
|
||||
'openai_moderation_enabled': True,
|
||||
'netdata_root': None
|
||||
'netdata_root': None,
|
||||
'show_backends': True,
|
||||
}
|
||||
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ def load_config(config_path):
|
|||
opts.openai_org_name = config['openai_org_name']
|
||||
opts.openai_silent_trim = config['openai_silent_trim']
|
||||
opts.openai_moderation_enabled = config['openai_moderation_enabled']
|
||||
opts.show_backends = config['show_backends']
|
||||
|
||||
if opts.openai_expose_our_model and not opts.openai_api_key:
|
||||
print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.')
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
import pickle
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Callable, List, Mapping, Union, Optional
|
||||
from typing import Callable, List, Mapping, Optional, Union
|
||||
|
||||
import redis as redis_pkg
|
||||
import simplejson as json
|
||||
from flask_caching import Cache
|
||||
from redis import Redis
|
||||
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, ZScoreBoundT, PatternT
|
||||
from redis.typing import AnyKeyT, EncodableT, ExpiryT, FieldT, KeyT, PatternT, ZScoreBoundT
|
||||
|
||||
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
|
||||
|
||||
|
@ -35,12 +35,12 @@ class RedisCustom:
|
|||
def set(self, key, value, ex: Union[ExpiryT, None] = None):
|
||||
return self.redis.set(self._key(key), value, ex=ex)
|
||||
|
||||
def get(self, key, dtype=None, default=None):
|
||||
"""
|
||||
:param key:
|
||||
:param dtype: convert to this type
|
||||
:return:
|
||||
"""
|
||||
def get(self, key, default=None, dtype=None):
|
||||
# TODO: use pickle
|
||||
import inspect
|
||||
if inspect.isclass(default):
|
||||
raise Exception
|
||||
|
||||
d = self.redis.get(self._key(key))
|
||||
if dtype and d:
|
||||
try:
|
||||
|
@ -153,11 +153,23 @@ class RedisCustom:
|
|||
keys = []
|
||||
for key in raw_keys:
|
||||
p = key.decode('utf-8').split(':')
|
||||
if len(p) > 2:
|
||||
if len(p) >= 2:
|
||||
# Delete prefix
|
||||
del p[0]
|
||||
keys.append(':'.join(p))
|
||||
k = ':'.join(p)
|
||||
if k != '____':
|
||||
keys.append(k)
|
||||
return keys
|
||||
|
||||
def pipeline(self, transaction=True, shard_hint=None):
|
||||
return self.redis.pipeline(transaction, shard_hint)
|
||||
|
||||
def exists(self, *names: KeyT):
|
||||
n = []
|
||||
for name in names:
|
||||
n.append(self._key(name))
|
||||
return self.redis.exists(*n)
|
||||
|
||||
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
|
||||
return self.set(key, json.dumps(dict_value), ex=ex)
|
||||
|
||||
|
@ -174,7 +186,7 @@ class RedisCustom:
|
|||
def getp(self, name: str):
|
||||
r = self.redis.get(name)
|
||||
if r:
|
||||
return pickle.load(r)
|
||||
return pickle.loads(r)
|
||||
return r
|
||||
|
||||
def flush(self):
|
||||
|
|
|
@ -1,15 +1,19 @@
|
|||
import json
|
||||
import time
|
||||
import traceback
|
||||
from threading import Thread
|
||||
|
||||
import llm_server
|
||||
from llm_server import opts
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.llm.vllm import tokenize
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, response_tokens: int = None, is_error: bool = False):
|
||||
def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens: int = None, is_error: bool = False):
|
||||
def background_task():
|
||||
nonlocal ip, token, prompt, response, gen_time, parameters, headers, backend_response_code, request_url, cluster_backend, response_tokens, is_error
|
||||
# Try not to shove JSON into the database.
|
||||
if isinstance(response, dict) and response.get('results'):
|
||||
response = response['results'][0]['text']
|
||||
try:
|
||||
|
@ -52,10 +56,15 @@ def log_prompt(ip, token, prompt, response, gen_time, parameters, headers, backe
|
|||
(ip, token, model, backend_mode, backend_url, request_url, generation_time, prompt, prompt_tokens, response, response_tokens, response_status, parameters, headers, timestamp)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
|
||||
""",
|
||||
(ip, token, running_model, opts.mode, opts.backend_url, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
(ip, token, running_model, opts.mode, cluster_backend, request_url, gen_time, prompt, prompt_tokens, response, response_tokens, backend_response_code, json.dumps(parameters), json.dumps(headers), timestamp))
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = Thread(target=background_task)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
|
||||
def is_valid_api_key(api_key):
|
||||
cursor = database.cursor()
|
||||
|
|
|
@ -60,7 +60,7 @@ def round_up_base(n, base):
|
|||
|
||||
|
||||
def auto_set_base_client_api(request):
|
||||
http_host = redis.get('http_host', str)
|
||||
http_host = redis.get('http_host', dtype=str)
|
||||
host = request.headers.get("Host")
|
||||
if http_host and not re.match(r'((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.?\b){4}', http_host):
|
||||
# If the current http_host is not an IP, don't do anything.
|
||||
|
|
|
@ -1,12 +0,0 @@
|
|||
import threading
|
||||
|
||||
|
||||
class ThreadSafeInteger:
|
||||
def __init__(self, value=0):
|
||||
self.value = value
|
||||
self._value_lock = threading.Lock()
|
||||
|
||||
def increment(self):
|
||||
with self._value_lock:
|
||||
self.value += 1
|
||||
return self.value
|
|
@ -3,7 +3,7 @@ from llm_server.custom_redis import redis
|
|||
|
||||
|
||||
def get_token_count(prompt: str):
|
||||
backend_mode = redis.get('backend_mode', str)
|
||||
backend_mode = redis.get('backend_mode', dtype=str)
|
||||
if backend_mode == 'vllm':
|
||||
return vllm.tokenize(prompt)
|
||||
elif backend_mode == 'ooba':
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
from llm_server import opts
|
||||
|
||||
|
||||
def generator(request_json_body):
|
||||
def generator(request_json_body, cluster_backend):
|
||||
if opts.mode == 'oobabooga':
|
||||
# from .oobabooga.generate import generate
|
||||
# return generate(request_json_body)
|
||||
raise NotImplementedError
|
||||
elif opts.mode == 'vllm':
|
||||
from .vllm.generate import generate
|
||||
r = generate(request_json_body)
|
||||
r = generate(request_json_body, cluster_backend)
|
||||
return r
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -3,19 +3,15 @@ import requests
|
|||
from llm_server import opts
|
||||
|
||||
|
||||
def get_running_model(backend_url: str):
|
||||
# TODO: remove this once we go to Redis
|
||||
if not backend_url:
|
||||
backend_url = opts.backend_url
|
||||
|
||||
if opts.mode == 'oobabooga':
|
||||
def get_running_model(backend_url: str, mode: str):
|
||||
if mode == 'ooba':
|
||||
try:
|
||||
backend_response = requests.get(f'{backend_url}/api/v1/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
|
||||
r_json = backend_response.json()
|
||||
return r_json['result'], None
|
||||
except Exception as e:
|
||||
return False, e
|
||||
elif opts.mode == 'vllm':
|
||||
elif mode == 'vllm':
|
||||
try:
|
||||
backend_response = requests.get(f'{backend_url}/model', timeout=opts.backend_request_timeout, verify=opts.verify_ssl)
|
||||
r_json = backend_response.json()
|
||||
|
|
|
@ -40,6 +40,6 @@ class LLMBackend:
|
|||
def validate_prompt(self, prompt: str) -> Tuple[bool, Union[str, None]]:
|
||||
prompt_len = get_token_count(prompt)
|
||||
if prompt_len > opts.context_size - 10:
|
||||
model_name = redis.get('running_model', str, 'NO MODEL ERROR')
|
||||
model_name = redis.get('running_model', 'NO MODEL ERROR', dtype=str)
|
||||
return False, f'Token indices sequence length is longer than the specified maximum sequence length for this model ({prompt_len} > {opts.context_size}, model: {model_name}). Please lower your context size'
|
||||
return True, None
|
||||
|
|
|
@ -34,7 +34,7 @@ def build_openai_response(prompt, response, model=None):
|
|||
# TODO: async/await
|
||||
prompt_tokens = llm_server.llm.get_token_count(prompt)
|
||||
response_tokens = llm_server.llm.get_token_count(response)
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"chatcmpl-{generate_oai_string(30)}",
|
||||
|
@ -57,7 +57,7 @@ def build_openai_response(prompt, response, model=None):
|
|||
}
|
||||
}), 200)
|
||||
|
||||
stats = redis.get('proxy_stats', dict)
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
|
|
|
@ -49,7 +49,7 @@ def transform_to_text(json_request, api_response):
|
|||
|
||||
prompt_tokens = len(llm_server.llm.get_token_count(prompt))
|
||||
completion_tokens = len(llm_server.llm.get_token_count(text))
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
# https://platform.openai.com/docs/api-reference/making-requests?lang=python
|
||||
return {
|
||||
|
@ -82,9 +82,9 @@ def transform_prompt_to_text(prompt: list):
|
|||
return text.strip('\n')
|
||||
|
||||
|
||||
def handle_blocking_request(json_data: dict):
|
||||
def handle_blocking_request(json_data: dict, cluster_backend):
|
||||
try:
|
||||
r = requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
r = requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
except requests.exceptions.ReadTimeout:
|
||||
print(f'Failed to reach VLLM inference endpoint - request to backend timed out')
|
||||
return False, None, 'Request to backend timed out'
|
||||
|
@ -97,11 +97,11 @@ def handle_blocking_request(json_data: dict):
|
|||
return True, r, None
|
||||
|
||||
|
||||
def generate(json_data: dict):
|
||||
def generate(json_data: dict, cluster_backend):
|
||||
if json_data.get('stream'):
|
||||
try:
|
||||
return requests.post(f'{opts.backend_url}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
return requests.post(f'{cluster_backend}/generate', json=prepare_json(json_data), stream=True, verify=opts.verify_ssl, timeout=opts.backend_generate_request_timeout)
|
||||
except Exception as e:
|
||||
print(f'Failed to reach VLLM inference endpoint -', f'{e.__class__.__name__}: {e}')
|
||||
else:
|
||||
return handle_blocking_request(json_data)
|
||||
return handle_blocking_request(json_data, cluster_backend)
|
||||
|
|
|
@ -19,17 +19,9 @@ class VLLMBackend(LLMBackend):
|
|||
# Failsafe
|
||||
backend_response = ''
|
||||
|
||||
r_url = request.url
|
||||
|
||||
def background_task():
|
||||
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=r_url,
|
||||
log_prompt(ip=client_ip, token=token, prompt=prompt, response=backend_response, gen_time=elapsed_time, parameters=parameters, headers=headers, backend_response_code=response_status_code, request_url=request.url,
|
||||
response_tokens=response_json_body.get('details', {}).get('generated_tokens'))
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
return jsonify({'results': [{'text': backend_response}]}), 200
|
||||
|
||||
def get_parameters(self, parameters) -> Tuple[dict | None, str | None]:
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
# TODO: rewrite the config system so I don't have to add every single config default here
|
||||
|
||||
running_model = 'ERROR'
|
||||
concurrent_gens = 3
|
||||
mode = 'oobabooga'
|
||||
backend_url = None
|
||||
|
@ -38,3 +37,4 @@ openai_org_name = 'OpenAI'
|
|||
openai_silent_trim = False
|
||||
openai_moderation_enabled = True
|
||||
cluster = {}
|
||||
show_backends = True
|
||||
|
|
|
@ -7,7 +7,7 @@ from llm_server.routes.v1.generate_stats import generate_stats
|
|||
|
||||
|
||||
def server_startup(s):
|
||||
if not redis.get('daemon_started', bool):
|
||||
if not redis.get('daemon_started', dtype=bool):
|
||||
print('Could not find the key daemon_started in Redis. Did you forget to start the daemon process?')
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def format_sillytavern_err(msg: str, level: str = 'info'):
|
||||
http_host = redis.get('http_host', str)
|
||||
def format_sillytavern_err(msg: str, backend_url: str, level: str = 'info'):
|
||||
cluster_backend_hash = cluster_config.get_backend_handler(backend_url)['hash']
|
||||
http_host = redis.get('http_host', dtype=str)
|
||||
return f"""```
|
||||
=== MESSAGE FROM LLM MIDDLEWARE AT {http_host} ===
|
||||
-> {level.upper()} <-
|
||||
{msg}
|
||||
|
||||
BACKEND HASH: {cluster_backend_hash}
|
||||
```"""
|
||||
|
|
|
@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
|
||||
backend_response = self.handle_error(msg)
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.cluster_backend, is_error=True)
|
||||
return backend_response[0], 200 # We only return the response from handle_error(), not the error code
|
||||
|
||||
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
|
||||
|
@ -40,7 +40,7 @@ class OobaRequestHandler(RequestHandler):
|
|||
# TODO: how to format this
|
||||
response_msg = error_msg
|
||||
else:
|
||||
response_msg = format_sillytavern_err(error_msg, error_type)
|
||||
response_msg = format_sillytavern_err(error_msg, error_type, self.cluster_backend)
|
||||
|
||||
return jsonify({
|
||||
'results': [{'text': response_msg}]
|
||||
|
|
|
@ -5,11 +5,12 @@ import traceback
|
|||
|
||||
from flask import Response, jsonify, request
|
||||
|
||||
from . import openai_bp
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp
|
||||
from ..helpers.http import validate_json
|
||||
from ..openai_request_handler import OpenAIRequestHandler
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt
|
||||
|
@ -48,10 +49,11 @@ def openai_chat_completions():
|
|||
'stream': True,
|
||||
}
|
||||
try:
|
||||
response = generator(msg_to_backend)
|
||||
cluster_backend = get_a_cluster_backend()
|
||||
response = generator(msg_to_backend, cluster_backend)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', str, 'ERROR') if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
|
@ -94,7 +96,7 @@ def openai_chat_completions():
|
|||
|
||||
def background_task():
|
||||
generated_tokens = tokenize(generated_text)
|
||||
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=generated_tokens)
|
||||
log_prompt(handler.client_ip, handler.token, handler.prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=generated_tokens)
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task)
|
||||
|
|
|
@ -29,7 +29,7 @@ def openai_completions():
|
|||
# TODO: async/await
|
||||
prompt_tokens = get_token_count(request_json_body['prompt'])
|
||||
response_tokens = get_token_count(output)
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
|
||||
response = make_response(jsonify({
|
||||
"id": f"cmpl-{generate_oai_string(30)}",
|
||||
|
@ -51,7 +51,7 @@ def openai_completions():
|
|||
}
|
||||
}), 200)
|
||||
|
||||
stats = redis.get('proxy_stats', dict)
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
|
|
|
@ -7,6 +7,7 @@ from . import openai_bp
|
|||
from llm_server.custom_redis import ONE_MONTH_SECONDS, flask_cache, redis
|
||||
from ..stats import server_start_time
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...helpers import jsonify_pretty
|
||||
from ...llm.info import get_running_model
|
||||
|
||||
|
@ -22,7 +23,7 @@ def openai_list_models():
|
|||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
else:
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = redis.get('running_model', 'ERROR', dtype=str)
|
||||
oai = fetch_openai_models()
|
||||
r = []
|
||||
if opts.openai_expose_our_model:
|
||||
|
|
|
@ -93,6 +93,6 @@ def incr_active_workers():
|
|||
|
||||
def decr_active_workers():
|
||||
redis.decr('active_gen_workers')
|
||||
new_count = redis.get('active_gen_workers', int, 0)
|
||||
new_count = redis.get('active_gen_workers', 0, dtype=int)
|
||||
if new_count < 0:
|
||||
redis.set('active_gen_workers', 0)
|
||||
|
|
|
@ -5,13 +5,15 @@ import flask
|
|||
from flask import Response, request
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_a_cluster_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.database import log_prompt
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.oobabooga.ooba_backend import OobaboogaBackend
|
||||
from llm_server.llm.vllm.vllm_backend import VLLMBackend
|
||||
from llm_server.routes.auth import parse_token
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.helpers.http import require_api_key, validate_json
|
||||
from llm_server.routes.queue import priority_queue
|
||||
|
||||
|
@ -35,7 +37,9 @@ class RequestHandler:
|
|||
self.client_ip = self.get_client_ip()
|
||||
self.token = self.get_auth_token()
|
||||
self.token_priority, self.token_simultaneous_ip = self.get_token_ratelimit()
|
||||
self.backend = get_backend()
|
||||
self.cluster_backend = get_a_cluster_backend()
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.cluster_backend)
|
||||
self.backend = get_backend_handler(self.cluster_backend)
|
||||
self.parameters = None
|
||||
self.used = False
|
||||
redis.zadd('recent_prompters', {self.client_ip: time.time()})
|
||||
|
@ -119,7 +123,7 @@ class RequestHandler:
|
|||
backend_response = self.handle_error(combined_error_message, 'Validation Error')
|
||||
|
||||
if do_log:
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), 0, self.parameters, dict(self.request.headers), 0, self.request.url, self.cluster_backend, is_error=True)
|
||||
return False, backend_response
|
||||
return True, (None, 0)
|
||||
|
||||
|
@ -131,7 +135,7 @@ class RequestHandler:
|
|||
request_valid, invalid_response = self.validate_request(prompt, do_log=True)
|
||||
if not request_valid:
|
||||
return (False, None, None, 0), invalid_response
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters), self.token_priority)
|
||||
event = priority_queue.put((llm_request, self.client_ip, self.token, self.parameters, self.cluster_backend), self.token_priority)
|
||||
else:
|
||||
event = None
|
||||
|
||||
|
@ -160,7 +164,7 @@ class RequestHandler:
|
|||
else:
|
||||
error_msg = error_msg.strip('.') + '.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -180,7 +184,7 @@ class RequestHandler:
|
|||
if return_json_err:
|
||||
error_msg = 'The backend did not return valid JSON.'
|
||||
backend_response = self.handle_error(error_msg)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, is_error=True)
|
||||
log_prompt(self.client_ip, self.token, prompt, backend_response[0].data.decode('utf-8'), elapsed_time, self.parameters, dict(self.request.headers), response_status_code, self.request.url, self.cluster_backend, is_error=True)
|
||||
return (False, None, None, 0), backend_response
|
||||
|
||||
# ===============================================
|
||||
|
@ -214,10 +218,10 @@ class RequestHandler:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_backend():
|
||||
if opts.mode == 'oobabooga':
|
||||
def get_backend_handler(mode):
|
||||
if mode == 'oobabooga':
|
||||
return OobaboogaBackend()
|
||||
elif opts.mode == 'vllm':
|
||||
elif mode == 'vllm':
|
||||
return VLLMBackend()
|
||||
else:
|
||||
raise Exception
|
||||
|
|
|
@ -2,32 +2,9 @@ from datetime import datetime
|
|||
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
# proompters_5_min = 0
|
||||
# concurrent_semaphore = Semaphore(concurrent_gens)
|
||||
|
||||
server_start_time = datetime.now()
|
||||
|
||||
|
||||
# TODO: do I need this?
|
||||
# def elapsed_times_cleanup():
|
||||
# global wait_in_queue_elapsed
|
||||
# while True:
|
||||
# current_time = time.time()
|
||||
# with wait_in_queue_elapsed_lock:
|
||||
# global wait_in_queue_elapsed
|
||||
# wait_in_queue_elapsed = [(end_time, elapsed_time) for end_time, elapsed_time in wait_in_queue_elapsed if current_time - end_time <= 60]
|
||||
# time.sleep(1)
|
||||
|
||||
|
||||
def calculate_avg_gen_time():
|
||||
# Get the average generation time from Redis
|
||||
average_generation_time = redis.get('average_generation_time')
|
||||
if average_generation_time is None:
|
||||
return 0
|
||||
else:
|
||||
return float(average_generation_time)
|
||||
|
||||
|
||||
def get_total_proompts():
|
||||
count = redis.get('proompts')
|
||||
if count is None:
|
||||
|
|
|
@ -2,11 +2,12 @@ import time
|
|||
from datetime import datetime
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_a_cluster_backend, test_backend
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import get_distinct_ips_24h, sum_column
|
||||
from llm_server.helpers import deep_sort, round_up_base
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.netdata import get_power_states
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.routes.queue import priority_queue
|
||||
from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
|
||||
|
||||
|
@ -33,52 +34,43 @@ def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, act
|
|||
return gen_time_calc
|
||||
|
||||
|
||||
# TODO: have routes/__init__.py point to the latest API version generate_stats()
|
||||
|
||||
def generate_stats(regen: bool = False):
|
||||
if not regen:
|
||||
c = redis.get('proxy_stats', dict)
|
||||
c = redis.get('proxy_stats', dtype=dict)
|
||||
if c:
|
||||
return c
|
||||
|
||||
model_name, error = get_running_model() # will return False when the fetch fails
|
||||
if isinstance(model_name, bool):
|
||||
online = False
|
||||
else:
|
||||
online = True
|
||||
redis.set('running_model', model_name)
|
||||
default_backend_url = get_a_cluster_backend()
|
||||
default_backend_info = cluster_config.get_backend(default_backend_url)
|
||||
if not default_backend_info.get('mode'):
|
||||
# TODO: remove
|
||||
print('DAEMON NOT FINISHED STARTING')
|
||||
return
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
|
||||
average_generation_elapsed_sec = redis.get('average_generation_elapsed_sec', 0)
|
||||
|
||||
# t = elapsed_times.copy() # copy since we do multiple operations and don't want it to change
|
||||
# if len(t) == 0:
|
||||
# estimated_wait = 0
|
||||
# else:
|
||||
# waits = [elapsed for end, elapsed in t]
|
||||
# estimated_wait = int(sum(waits) / len(waits))
|
||||
online = test_backend(default_backend_url, default_backend_info['mode'])
|
||||
if online:
|
||||
running_model, err = get_running_model(default_backend_url, default_backend_info['mode'])
|
||||
cluster_config.set_backend_value(default_backend_url, 'running_model', running_model)
|
||||
else:
|
||||
running_model = None
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
proompters_in_queue = len(priority_queue)
|
||||
|
||||
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
|
||||
# This is so wildly inaccurate it's disabled.
|
||||
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
|
||||
|
||||
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
|
||||
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
|
||||
|
||||
if opts.netdata_root:
|
||||
netdata_stats = {}
|
||||
power_states = get_power_states()
|
||||
for gpu, power_state in power_states.items():
|
||||
netdata_stats[gpu] = {
|
||||
'power_state': power_state,
|
||||
# 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
|
||||
}
|
||||
else:
|
||||
netdata_stats = {}
|
||||
|
||||
base_client_api = redis.get('base_client_api', str)
|
||||
proompters_5_min = len(redis.zrangebyscore('recent_prompters', time.time() - 5 * 60, '+inf'))
|
||||
# TODO: make this for the currently selected backend
|
||||
estimated_wait_sec = calculate_wait_time(average_generation_elapsed_sec, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
|
||||
|
||||
output = {
|
||||
'default': {
|
||||
'model': running_model,
|
||||
'backend': default_backend_info['hash'],
|
||||
},
|
||||
'stats': {
|
||||
'proompters': {
|
||||
'5_min': proompters_5_min,
|
||||
|
@ -86,9 +78,10 @@ def generate_stats(regen: bool = False):
|
|||
},
|
||||
'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
|
||||
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
|
||||
'average_generation_elapsed_sec': int(average_generation_time),
|
||||
'average_generation_elapsed_sec': int(average_generation_elapsed_sec),
|
||||
# 'estimated_avg_tps': estimated_avg_tps,
|
||||
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
|
||||
'num_backends': len(cluster_config.all()) if opts.show_backends else None,
|
||||
},
|
||||
'online': online,
|
||||
'endpoints': {
|
||||
|
@ -103,10 +96,7 @@ def generate_stats(regen: bool = False):
|
|||
'timestamp': int(time.time()),
|
||||
'config': {
|
||||
'gatekeeper': 'none' if opts.auth_required is False else 'token',
|
||||
'context_size': opts.context_size,
|
||||
'concurrent': opts.concurrent_gens,
|
||||
'model': opts.manual_model_name if opts.manual_model_name else model_name,
|
||||
'mode': opts.mode,
|
||||
'simultaneous_requests_per_ip': opts.simultaneous_requests_per_ip,
|
||||
},
|
||||
'keys': {
|
||||
|
@ -114,8 +104,41 @@ def generate_stats(regen: bool = False):
|
|||
'anthropicKeys': '∞',
|
||||
},
|
||||
'backend_info': redis.get_dict('backend_info') if opts.show_backend_info else None,
|
||||
}
|
||||
|
||||
if opts.show_backends:
|
||||
for backend_url, v in cluster_config.all().items():
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
if not backend_info['online']:
|
||||
continue
|
||||
|
||||
# TODO: have this fetch the data from VLLM which will display GPU utalization
|
||||
# if opts.netdata_root:
|
||||
# netdata_stats = {}
|
||||
# power_states = get_power_states()
|
||||
# for gpu, power_state in power_states.items():
|
||||
# netdata_stats[gpu] = {
|
||||
# 'power_state': power_state,
|
||||
# # 'wh_wasted_1_hr': get_gpu_wh(int(gpu.strip('gpu')))
|
||||
# }
|
||||
# else:
|
||||
# netdata_stats = {}
|
||||
netdata_stats = {}
|
||||
|
||||
# TODO: use value returned by VLLM backend here
|
||||
# backend_uptime = int((datetime.now() - backend_info['start_time']).total_seconds()) if opts.show_uptime else None
|
||||
backend_uptime = -1
|
||||
|
||||
output['backend_info'][backend_info['hash']] = {
|
||||
'uptime': backend_uptime,
|
||||
# 'context_size': opts.context_size,
|
||||
'model': opts.manual_model_name if opts.manual_model_name else backend_info.get('running_model', 'ERROR'),
|
||||
'mode': backend_info['mode'],
|
||||
'nvidia': netdata_stats
|
||||
}
|
||||
else:
|
||||
output['backend_info'] = {}
|
||||
|
||||
result = deep_sort(output)
|
||||
|
||||
# It may take a bit to get the base client API, so don't cache until then.
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
@ -10,10 +9,11 @@ from ..helpers.http import require_api_key, validate_json
|
|||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
|
||||
from ... import opts
|
||||
from ...cluster.backend import get_a_cluster_backend
|
||||
from ...database.database import log_prompt
|
||||
from ...llm.generator import generator
|
||||
from ...llm.vllm import tokenize
|
||||
from ...stream import sock
|
||||
from ...sock import sock
|
||||
|
||||
|
||||
# TODO: have workers process streaming requests
|
||||
|
@ -35,19 +35,13 @@ def stream(ws):
|
|||
log_in_bg(quitting_err_msg, is_error=True)
|
||||
|
||||
def log_in_bg(generated_text_bg, elapsed_time_bg: Union[int, float] = None, is_error: bool = False, status_code: int = None):
|
||||
|
||||
def background_task_exception():
|
||||
generated_tokens = tokenize(generated_text_bg)
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, response_tokens=generated_tokens, is_error=is_error)
|
||||
|
||||
# TODO: use async/await instead of threads
|
||||
thread = threading.Thread(target=background_task_exception)
|
||||
thread.start()
|
||||
thread.join()
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text_bg, elapsed_time_bg, handler.parameters, r_headers, status_code, r_url, cluster_backend, response_tokens=generated_tokens, is_error=is_error)
|
||||
|
||||
if not opts.enable_streaming:
|
||||
return 'Streaming is disabled', 401
|
||||
|
||||
cluster_backend = None
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
message_num = 0
|
||||
|
@ -90,14 +84,15 @@ def stream(ws):
|
|||
}
|
||||
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None), handler.token_priority)
|
||||
event = priority_queue.put((None, handler.client_ip, handler.token, None, None), handler.token_priority)
|
||||
if not event:
|
||||
r, _ = handler.handle_ratelimited()
|
||||
err_msg = r.json['results'][0]['text']
|
||||
send_err_and_quit(err_msg)
|
||||
return
|
||||
try:
|
||||
response = generator(llm_request)
|
||||
cluster_backend = get_a_cluster_backend()
|
||||
response = generator(llm_request, cluster_backend)
|
||||
if not response:
|
||||
error_msg = 'Failed to reach backend while streaming.'
|
||||
print('Streaming failed:', error_msg)
|
||||
|
@ -142,7 +137,7 @@ def stream(ws):
|
|||
ws.close()
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
|
||||
return
|
||||
|
||||
message_num += 1
|
||||
|
@ -181,5 +176,5 @@ def stream(ws):
|
|||
# The client closed the stream.
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, response_tokens=tokenize(generated_text))
|
||||
log_prompt(handler.client_ip, handler.token, input_prompt, generated_text, elapsed_time, handler.parameters, r_headers, response_status_code, r_url, cluster_backend, response_tokens=tokenize(generated_text))
|
||||
ws.close() # this is important if we encountered and error and exited early.
|
||||
|
|
|
@ -2,22 +2,21 @@ import time
|
|||
|
||||
from flask import jsonify, request
|
||||
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from . import bp
|
||||
from ..auth import requires_auth
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from ... import opts
|
||||
from ...llm.info import get_running_model
|
||||
|
||||
|
||||
# @bp.route('/info', methods=['GET'])
|
||||
# # @cache.cached(timeout=3600, query_string=True)
|
||||
# def get_info():
|
||||
# # requests.get()
|
||||
# return 'yes'
|
||||
from ...cluster.backend import get_a_cluster_backend, get_backends, get_backends_from_model, is_valid_model
|
||||
from ...cluster.cluster_config import cluster_config
|
||||
|
||||
|
||||
@bp.route('/model', methods=['GET'])
|
||||
def get_model():
|
||||
@bp.route('/<model_name>/model', methods=['GET'])
|
||||
def get_model(model_name=None):
|
||||
if not model_name:
|
||||
b = get_a_cluster_backend()
|
||||
model_name = cluster_config.get_backend(b)['running_model']
|
||||
|
||||
# We will manage caching ourself since we don't want to cache
|
||||
# when the backend is down. Also, Cloudflare won't cache 500 errors.
|
||||
cache_key = 'model_cache::' + request.url
|
||||
|
@ -26,16 +25,17 @@ def get_model():
|
|||
if cached_response:
|
||||
return cached_response
|
||||
|
||||
model_name, error = get_running_model()
|
||||
if not model_name:
|
||||
if not is_valid_model(model_name):
|
||||
response = jsonify({
|
||||
'code': 502,
|
||||
'msg': 'failed to reach backend',
|
||||
'type': error.__class__.__name__
|
||||
}), 500 # return 500 so Cloudflare doesn't intercept us
|
||||
'code': 400,
|
||||
'msg': 'Model does not exist.',
|
||||
}), 400
|
||||
else:
|
||||
num_backends = len(get_backends_from_model(model_name))
|
||||
|
||||
response = jsonify({
|
||||
'result': opts.manual_model_name if opts.manual_model_name else model_name,
|
||||
'model_backend_count': num_backends,
|
||||
'timestamp': int(time.time())
|
||||
}), 200
|
||||
flask_cache.set(cache_key, response, timeout=60)
|
||||
|
@ -43,7 +43,11 @@ def get_model():
|
|||
return response
|
||||
|
||||
|
||||
@bp.route('/backend', methods=['GET'])
|
||||
@bp.route('/backends', methods=['GET'])
|
||||
@requires_auth
|
||||
def get_backend():
|
||||
return jsonify({'backend': opts.backend_url, 'mode': opts.mode}), 200
|
||||
online, offline = get_backends()
|
||||
result = []
|
||||
for i in online + offline:
|
||||
result.append(cluster_config.get_backend(i))
|
||||
return jsonify(result), 200
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
from threading import Thread
|
||||
|
||||
from .blocking import start_workers
|
||||
from .main import main_background_thread
|
||||
from .moderator import start_moderation_workers
|
||||
from .printer import console_printer
|
||||
from .recent import recent_prompters_thread
|
||||
from .threads import cache_stats
|
||||
from .. import opts
|
||||
|
||||
|
||||
def start_background():
|
||||
start_workers(opts.concurrent_gens)
|
||||
|
||||
t = Thread(target=main_background_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the main background thread.')
|
||||
|
||||
start_moderation_workers(opts.openai_moderation_workers)
|
||||
|
||||
t = Thread(target=cache_stats)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the stats cacher.')
|
||||
|
||||
t = Thread(target=recent_prompters_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the recent proompters thread.')
|
||||
|
||||
t = Thread(target=console_printer)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the console printer.')
|
|
@ -2,15 +2,15 @@ import threading
|
|||
import time
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.routes.queue import DataEvent, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count, priority_queue
|
||||
|
||||
|
||||
def worker():
|
||||
while True:
|
||||
need_to_wait()
|
||||
(request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
|
||||
(request_json_body, client_ip, token, parameters, cluster_backend), event_id = priority_queue.get()
|
||||
need_to_wait()
|
||||
|
||||
increment_ip_count(client_ip, 'processing_ips')
|
||||
|
@ -22,7 +22,7 @@ def worker():
|
|||
continue
|
||||
|
||||
try:
|
||||
success, response, error_msg = generator(request_json_body)
|
||||
success, response, error_msg = generator(request_json_body, cluster_backend)
|
||||
event = DataEvent(event_id)
|
||||
event.set((success, response, error_msg))
|
||||
finally:
|
||||
|
@ -42,7 +42,7 @@ def start_workers(num_workers: int):
|
|||
|
||||
def need_to_wait():
|
||||
# We need to check the number of active workers since the streaming endpoint may be doing something.
|
||||
active_workers = redis.get('active_gen_workers', int, 0)
|
||||
active_workers = redis.get('active_gen_workers', 0, dtype=int)
|
||||
s = time.time()
|
||||
while active_workers >= opts.concurrent_gens:
|
||||
time.sleep(0.01)
|
|
@ -1,55 +0,0 @@
|
|||
import time
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.database.database import weighted_average_column_for_model
|
||||
from llm_server.llm.info import get_running_model
|
||||
from llm_server.custom_redis import redis
|
||||
|
||||
|
||||
def main_background_thread():
|
||||
redis.set('average_generation_elapsed_sec', 0)
|
||||
redis.set('estimated_avg_tps', 0)
|
||||
redis.set('average_output_tokens', 0)
|
||||
redis.set('backend_online', 0)
|
||||
redis.set_dict('backend_info', {})
|
||||
|
||||
while True:
|
||||
# TODO: unify this
|
||||
if opts.mode == 'oobabooga':
|
||||
running_model, err = get_running_model()
|
||||
if err:
|
||||
print(err)
|
||||
redis.set('backend_online', 0)
|
||||
else:
|
||||
redis.set('running_model', running_model)
|
||||
redis.set('backend_online', 1)
|
||||
elif opts.mode == 'vllm':
|
||||
running_model, err = get_running_model()
|
||||
if err:
|
||||
print(err)
|
||||
redis.set('backend_online', 0)
|
||||
else:
|
||||
redis.set('running_model', running_model)
|
||||
redis.set('backend_online', 1)
|
||||
else:
|
||||
raise Exception
|
||||
|
||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
if average_generation_elapsed_sec: # returns None on exception
|
||||
redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)
|
||||
|
||||
# overall = average_column_for_model('prompts', 'generation_time', opts.running_model)
|
||||
# print(f'Weighted: {average_generation_elapsed_sec}, overall: {overall}')
|
||||
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens', running_model, opts.mode, opts.backend_url, exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
if average_generation_elapsed_sec:
|
||||
redis.set('average_output_tokens', average_output_tokens)
|
||||
|
||||
# overall = average_column_for_model('prompts', 'response_tokens', opts.running_model)
|
||||
# print(f'Weighted: {average_output_tokens}, overall: {overall}')
|
||||
|
||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||
redis.set('estimated_avg_tps', estimated_avg_tps)
|
||||
time.sleep(60)
|
|
@ -0,0 +1,56 @@
|
|||
import time
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import redis
|
||||
from llm_server.database.database import weighted_average_column_for_model
|
||||
from llm_server.llm.info import get_running_model
|
||||
|
||||
|
||||
def main_background_thread():
|
||||
while True:
|
||||
online, offline = get_backends()
|
||||
for backend_url in online:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
backend_mode = backend_info['mode']
|
||||
running_model, err = get_running_model(backend_url, backend_mode)
|
||||
if err:
|
||||
continue
|
||||
|
||||
average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps = calc_stats_for_backend(backend_url, running_model, backend_mode)
|
||||
if average_generation_elapsed_sec: # returns None on exception
|
||||
cluster_config.set_backend_value(backend_url, 'average_generation_elapsed_sec', average_generation_elapsed_sec)
|
||||
if average_output_tokens:
|
||||
cluster_config.set_backend_value(backend_url, 'average_output_tokens', average_output_tokens)
|
||||
if average_generation_elapsed_sec and average_output_tokens:
|
||||
cluster_config.set_backend_value(backend_url, 'estimated_avg_tps', estimated_avg_tps)
|
||||
|
||||
default_backend_url = get_a_cluster_backend()
|
||||
default_backend_info = cluster_config.get_backend(default_backend_url)
|
||||
default_backend_mode = default_backend_info['mode']
|
||||
default_running_model, err = get_running_model(default_backend_url, default_backend_mode)
|
||||
if err:
|
||||
continue
|
||||
|
||||
default_average_generation_elapsed_sec, default_average_output_tokens, default_estimated_avg_tps = calc_stats_for_backend(default_running_model, default_running_model, default_backend_mode)
|
||||
if default_average_generation_elapsed_sec:
|
||||
redis.set('average_generation_elapsed_sec', default_average_generation_elapsed_sec)
|
||||
if default_average_output_tokens:
|
||||
redis.set('average_output_tokens', default_average_output_tokens)
|
||||
if default_average_generation_elapsed_sec and default_average_output_tokens:
|
||||
redis.set('estimated_avg_tps', default_estimated_avg_tps)
|
||||
time.sleep(30)
|
||||
|
||||
|
||||
def calc_stats_for_backend(backend_url, running_model, backend_mode):
|
||||
# exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
|
||||
# was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
|
||||
average_generation_elapsed_sec = weighted_average_column_for_model('prompts', 'generation_time',
|
||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||
include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
average_output_tokens = weighted_average_column_for_model('prompts', 'response_tokens',
|
||||
running_model, backend_mode, backend_url, exclude_zeros=True,
|
||||
include_system_tokens=opts.include_system_tokens_in_stats) or 0
|
||||
estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2) if average_generation_elapsed_sec > 0 else 0 # Avoid division by zero
|
||||
return average_generation_elapsed_sec, average_output_tokens, estimated_avg_tps
|
|
@ -0,0 +1,50 @@
|
|||
import time
|
||||
from threading import Thread
|
||||
|
||||
from llm_server import opts
|
||||
from llm_server.cluster.stores import redis_running_models
|
||||
from llm_server.cluster.worker import cluster_worker
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.workers.inferencer import start_workers
|
||||
from llm_server.workers.mainer import main_background_thread
|
||||
from llm_server.workers.moderator import start_moderation_workers
|
||||
from llm_server.workers.printer import console_printer
|
||||
from llm_server.workers.recenter import recent_prompters_thread
|
||||
|
||||
|
||||
def cache_stats():
|
||||
while True:
|
||||
generate_stats(regen=True)
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def start_background():
|
||||
start_workers(opts.concurrent_gens)
|
||||
|
||||
t = Thread(target=main_background_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the main background thread.')
|
||||
|
||||
start_moderation_workers(opts.openai_moderation_workers)
|
||||
|
||||
t = Thread(target=cache_stats)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the stats cacher.')
|
||||
|
||||
t = Thread(target=recent_prompters_thread)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the recent proompters thread.')
|
||||
|
||||
t = Thread(target=console_printer)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the console printer.')
|
||||
|
||||
redis_running_models.flush()
|
||||
t = Thread(target=cluster_worker)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
print('Started the cluster worker.')
|
|
@ -1,9 +0,0 @@
|
|||
import time
|
||||
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
||||
|
||||
def cache_stats():
|
||||
while True:
|
||||
generate_stats(regen=True)
|
||||
time.sleep(5)
|
|
@ -1,3 +1,8 @@
|
|||
"""
|
||||
This file is used to run certain tasks when the HTTP server starts.
|
||||
It's located here so it doesn't get imported with daemon.py
|
||||
"""
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
49
server.py
49
server.py
|
@ -1,4 +1,4 @@
|
|||
from llm_server.config.config import mode_ui_names
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
|
||||
try:
|
||||
import gevent.monkey
|
||||
|
@ -7,8 +7,6 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
from llm_server.pre_fork import server_startup
|
||||
from llm_server.config.load import load_config, parse_backends
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
@ -16,14 +14,17 @@ from pathlib import Path
|
|||
import simplejson as json
|
||||
from flask import Flask, jsonify, render_template, request
|
||||
|
||||
import llm_server
|
||||
from llm_server.cluster.backend import get_a_cluster_backend, get_backends
|
||||
from llm_server.cluster.redis_cycle import load_backend_cycle
|
||||
from llm_server.config.config import mode_ui_names
|
||||
from llm_server.config.load import load_config, parse_backends
|
||||
from llm_server.database.conn import database
|
||||
from llm_server.database.create import create_db
|
||||
from llm_server.llm import get_token_count
|
||||
from llm_server.pre_fork import server_startup
|
||||
from llm_server.routes.openai import openai_bp
|
||||
from llm_server.routes.server_error import handle_server_error
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.stream import init_socketio
|
||||
from llm_server.sock import init_socketio
|
||||
|
||||
# TODO: set VLLM to stream ALL data using socket.io. If the socket disconnects, cancel generation.
|
||||
# TODO: add backend fallbacks. Backends at the bottom of the list are higher priority and are fallbacks if the upper ones fail
|
||||
|
@ -37,6 +38,8 @@ from llm_server.stream import init_socketio
|
|||
# TODO: use coloredlogs
|
||||
# TODO: need to update opts. for workers
|
||||
# TODO: add a healthcheck to VLLM
|
||||
# TODO: allow choosing the model by the URL path
|
||||
# TODO: have VLLM report context size, uptime
|
||||
|
||||
# Lower priority
|
||||
# TODO: estiamted wait time needs to account for full concurrent_gens but the queue is less than concurrent_gens
|
||||
|
@ -64,7 +67,7 @@ import config
|
|||
from llm_server import opts
|
||||
from llm_server.helpers import auto_set_base_client_api
|
||||
from llm_server.llm.vllm.info import vllm_info
|
||||
from llm_server.custom_redis import RedisCustom, flask_cache
|
||||
from llm_server.custom_redis import flask_cache
|
||||
from llm_server.llm import redis
|
||||
from llm_server.routes.stats import get_active_gen_workers
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
|
@ -83,20 +86,18 @@ if config_path_environ:
|
|||
else:
|
||||
config_path = Path(script_path, 'config', 'config.yml')
|
||||
|
||||
success, config, msg = load_config(config_path, script_path)
|
||||
success, config, msg = load_config(config_path)
|
||||
if not success:
|
||||
print('Failed to load config:', msg)
|
||||
sys.exit(1)
|
||||
|
||||
database.init_db(config['mysql']['host'], config['mysql']['username'], config['mysql']['password'], config['mysql']['database'])
|
||||
create_db()
|
||||
llm_server.llm.redis = RedisCustom('local_llm')
|
||||
create_db()
|
||||
|
||||
x = parse_backends(config)
|
||||
print(x)
|
||||
|
||||
# print(app.url_map)
|
||||
cluster_config.clear()
|
||||
cluster_config.load(parse_backends(config))
|
||||
on, off = get_backends()
|
||||
load_backend_cycle('backend_cycler', on + off)
|
||||
|
||||
|
||||
@app.route('/')
|
||||
|
@ -104,12 +105,18 @@ print(x)
|
|||
@app.route('/api/openai')
|
||||
@flask_cache.cached(timeout=10)
|
||||
def home():
|
||||
stats = generate_stats()
|
||||
# Use the default backend
|
||||
backend_url = get_a_cluster_backend()
|
||||
if backend_url:
|
||||
backend_info = cluster_config.get_backend(backend_url)
|
||||
stats = generate_stats(backend_url)
|
||||
else:
|
||||
backend_info = stats = None
|
||||
|
||||
if not stats['online']:
|
||||
running_model = estimated_wait_sec = 'offline'
|
||||
else:
|
||||
running_model = redis.get('running_model', str, 'ERROR')
|
||||
running_model = backend_info['running_model']
|
||||
|
||||
active_gen_workers = get_active_gen_workers()
|
||||
if stats['queue']['queued'] == 0 and active_gen_workers >= opts.concurrent_gens:
|
||||
|
@ -130,10 +137,16 @@ def home():
|
|||
info_html = ''
|
||||
|
||||
mode_info = ''
|
||||
if opts.mode == 'vllm':
|
||||
using_vllm = False
|
||||
for k, v in cluster_config.all().items():
|
||||
if v['mode'] == vllm:
|
||||
using_vllm = True
|
||||
break
|
||||
|
||||
if using_vllm == 'vllm':
|
||||
mode_info = vllm_info
|
||||
|
||||
base_client_api = redis.get('base_client_api', str)
|
||||
base_client_api = redis.get('base_client_api', dtype=str)
|
||||
|
||||
return render_template('home.html',
|
||||
llm_middleware_name=opts.llm_middleware_name,
|
||||
|
|
|
@ -7,23 +7,33 @@ except ImportError:
|
|||
|
||||
import time
|
||||
from threading import Thread
|
||||
from llm_server.cluster.redis_cycle import load_backend_cycle
|
||||
|
||||
from llm_server.cluster.funcs.backend import get_best_backends
|
||||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
from llm_server.cluster.backend import get_backends, get_a_cluster_backend
|
||||
from llm_server.cluster.worker import cluster_worker
|
||||
from llm_server.config.load import parse_backends, load_config
|
||||
from llm_server.cluster.redis_config_cache import RedisClusterStore
|
||||
|
||||
success, config, msg = load_config('./config/config.yml').resolve().absolute()
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('config')
|
||||
args = parser.parse_args()
|
||||
|
||||
success, config, msg = load_config(args.config)
|
||||
|
||||
cluster_config = RedisClusterStore('cluster_config')
|
||||
cluster_config.clear()
|
||||
cluster_config.load(parse_backends(config))
|
||||
on, off = get_backends()
|
||||
load_backend_cycle('backend_cycler', on + off)
|
||||
|
||||
t = Thread(target=cluster_worker)
|
||||
t.daemon = True
|
||||
t.start()
|
||||
|
||||
while True:
|
||||
x = get_best_backends()
|
||||
print(x)
|
||||
# online, offline = get_backends()
|
||||
# print(online, offline)
|
||||
# print(get_a_cluster_backend())
|
||||
time.sleep(3)
|
||||
|
|
Reference in New Issue