Merge cluster to master #3

daemon.py
@@ -1,4 +1,5 @@
 import argparse
+import logging
 import os
 import sys
 import time
@@ -10,6 +11,7 @@ from llm_server.cluster.cluster_config import cluster_config
 from llm_server.config.load import load_config, parse_backends
 from llm_server.custom_redis import redis
 from llm_server.database.create import create_db
+from llm_server.logging import create_logger, logging_info, init_logging
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.threader import start_background
 
@@ -21,18 +23,26 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description='Daemon microservice.')
+    parser = argparse.ArgumentParser(description='Daemon microservice.')
     parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
+    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
     args = parser.parse_args()
 
+    # TODO: have this be set by either the arg or a config value
+    if args.debug:
+        logging_info.level = logging.DEBUG
+
+    init_logging()
+    logger = create_logger('daemon')
+    logger.debug('Debug logging enabled.')
+
     if not args.no_reset:
         Redis().flushall()
-        print('Flushed Redis.')
+        logger.info('Flushed Redis.')
 
     success, config, msg = load_config(config_path)
     if not success:
-        print('Failed to load config:', msg)
+        logger.info(f'Failed to load config: {msg}')
         sys.exit(1)
 
     create_db()
@@ -40,13 +50,13 @@ if __name__ == "__main__":
     cluster_config.clear()
     cluster_config.load(parse_backends(config))
 
-    print('Loading backend stats...')
+    logger.info('Loading backend stats...')
     generate_stats()
 
     start_background()
 
     redis.set('daemon_started', 1)
-    print('== Daemon Setup Complete ==\n')
+    logger.info('== Daemon Setup Complete ==')
 
     try:
         while True:

@@ -1,8 +1,8 @@
 import time
 from threading import Thread
 
-from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.backend import test_backend
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.stores import redis_running_models
 
 
@@ -26,7 +26,6 @@ def cluster_worker():
 
 def check_backend(n, v, test_prompt):
     online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
-    # purge_backend_from_running_models(n)
     if online:
         running_model = backend_info['model']
         for k, v in backend_info.items():
@@ -36,7 +35,4 @@ def check_backend(n, v, test_prompt):
         for model in redis_running_models.keys():
            redis_running_models.srem(model, n)
 
-        # redis_running_models.srem(backend_info['model'], n)
-        # backend_cycler_store.lrem(backend_info['model'], 1, n)
-
    cluster_config.set_backend_value(n, 'online', online)

@@ -1,10 +1,28 @@
+import tiktoken
+
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.llm import oobabooga, vllm
+from llm_server.logging import create_logger
 
 
+def fallback_tokenizer(prompt: str):
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+    return len(tokenizer.encode(prompt)) + 10
+
+
 def get_token_count(prompt: str, backend_url: str):
     backend_url = cluster_config.validate_backend(backend_url)
-    backend_mode = cluster_config.get_backend(backend_url)['mode']
+    if not backend_url:
+        logger = create_logger('tokenizer')
+        logger.warning('using fallback tokenizer as there is no valid backend')
+        return fallback_tokenizer(prompt)
+
+    backend_mode = cluster_config.get_backend(backend_url).get('mode')
+    if not backend_mode:
+        logger = create_logger('tokenizer')
+        logger.warning("using fallback tokenizer as the backend isn't initalized")
+        return fallback_tokenizer(prompt)
+
     if backend_mode == 'vllm':
         return vllm.tokenize(prompt, backend_url)
     elif backend_mode == 'ooba':
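
The fallback estimate introduced above is simply tiktoken's cl100k_base token count plus a fixed 10-token safety margin. A minimal standalone sketch of that calculation (the sample prompt string is made up):

    import tiktoken

    # Same scheme as fallback_tokenizer() in the hunk above: cl100k_base count + 10.
    encoder = tiktoken.get_encoding("cl100k_base")
    prompt = "An example prompt."
    print(len(encoder.encode(prompt)) + 10)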

@@ -5,6 +5,7 @@ import tiktoken
 
 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
+from llm_server.logging import create_logger
 
 
 def tokenize(prompt: str, backend_url: str) -> int:
@@ -16,6 +17,8 @@ def tokenize(prompt: str, backend_url: str) -> int:
         return 0
     assert isinstance(prompt, str)
 
+    logger = create_logger('tokenizer')
+
     # The backend could have died between when the request was
     # submitted and now, so let's double check it's still online.
     backend_url = cluster_config.validate_backend(backend_url)
@@ -33,7 +36,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
             j = r.json()
             return j['length']
         except Exception as e:
-            print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
+            logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
             return len(tokenizer.encode(chunk)) + 10
 
     # Use a ThreadPoolExecutor to send all chunks to the server at once
@@ -44,5 +47,5 @@ def tokenize(prompt: str, backend_url: str) -> int:
         try:
             data = future.result()
         except Exception as exc:
-            print('%r generated an exception: %s' % (chunk, exc))
+            logger.warning('%r generated an exception: %s' % (chunk, exc))
     return sum(future.result() for future in future_to_chunk)

@@ -0,0 +1,52 @@
+import logging
+
+import coloredlogs
+
+from llm_server import opts
+
+
+class LoggingInfo:
+    def __init__(self):
+        self._level = logging.INFO
+        self._format = opts.LOGGING_FORMAT
+
+    @property
+    def level(self):
+        return self._level
+
+    @level.setter
+    def level(self, value):
+        self._level = value
+
+    @property
+    def format(self):
+        return self._format
+
+    @format.setter
+    def format(self, value):
+        self._format = value
+
+
+logging_info = LoggingInfo()
+
+
+def init_logging():
+    """
+    Set up the parent logger.
+    :return:
+    """
+    logger = logging.getLogger('llm_server')
+    logger.setLevel(logging_info.level)
+
+
+def create_logger(name):
+    logger = logging.getLogger('llm_server').getChild(name)
+    logger.setLevel(logging_info.level)
+    if not logger.handlers:
+        handler = logging.StreamHandler()
+        handler.setLevel(logging_info.level)
+        formatter = logging.Formatter(logging_info.format)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        coloredlogs.install(logger=logger, level=logging_info.level)
+    return logger
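
The rest of the PR obtains its loggers from this new module. A rough usage sketch under the same assumptions as the diff (the 'example' child-logger name is illustrative; LOGGING_FORMAT comes from the opts change further down):

    import logging

    from llm_server.logging import create_logger, init_logging, logging_info

    # Raise the shared level before handlers are created, as daemon.py does for --debug.
    logging_info.level = logging.DEBUG
    init_logging()  # configures the parent 'llm_server' logger

    logger = create_logger('example')  # child logger 'llm_server.example' with coloredlogs attached
    logger.debug('debug logging enabled')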

@@ -1,5 +1,8 @@
 # Read-only global variables
 
+# Uppercase variables are read-only globals.
+# Lowercase variables are ones that are set on startup and are never changed.
+
 # TODO: rewrite the config system so I don't have to add every single config default here
 
 frontend_api_mode = 'ooba'
@@ -39,3 +42,5 @@ openai_moderation_timeout = 5
 prioritize_by_size = False
 cluster_workers = 0
 redis_stream_timeout = 25000
+
+LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"

@@ -16,7 +16,7 @@ class OobaRequestHandler(RequestHandler):
     def handle_request(self, return_ok: bool = True):
         assert not self.used
         if self.offline:
-            print(messages.BACKEND_OFFLINE)
+            print('This backend is offline:', messages.BACKEND_OFFLINE)
             return self.handle_error(messages.BACKEND_OFFLINE)
 
         request_valid, invalid_response = self.validate_request()

@@ -24,7 +24,7 @@ def handle_error(e):
     "auth_subrequest_error"
     """
 
-    print(e)
+    print('OAI returning error:', e)
     return jsonify({
         "error": {
             "message": "Internal server error",

@@ -29,9 +29,7 @@ def openai_chat_completions(model_name=None):
     else:
         handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
         if handler.offline:
-            msg = return_invalid_model_err(model_name)
-            print(msg)
-            return handler.handle_error(msg)
+            return return_invalid_model_err(model_name)
 
         if not request_json_body.get('stream'):
             try:
@@ -100,7 +98,8 @@ def openai_chat_completions(model_name=None):
         # Need to do this before we enter generate() since we want to be able to
         # return a 408 if necessary.
         _, stream_name, error_msg = event.wait()
-        if error_msg == 'closed':
+        if error_msg:
+            print('OAI failed to start streaming:', error_msg)
             stream_name = None  # set to null so that the Finally ignores it.
             return 'Request Timeout', 408
 
@@ -120,7 +119,8 @@ def openai_chat_completions(model_name=None):
             timestamp = int(stream_index.decode('utf-8').split('-')[0])
             data = ujson.loads(item[b'data'])
             if data['error']:
-                print('OAI streaming error:', data['error'])
+                # Not printing error since we can just check the daemon log.
+                print('OAI streaming encountered error')
                 yield 'data: [DONE]\n\n'
                 return
             elif data['new']:

@@ -29,9 +29,7 @@ def openai_completions(model_name=None):
     else:
         handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
         if handler.offline:
-            msg = return_invalid_model_err(model_name)
-            print(msg)
-            return handler.handle_error(msg)
+            return return_invalid_model_err(model_name)
 
         if handler.cluster_backend_info['mode'] != 'vllm':
             # TODO: implement other backends
@@ -145,7 +143,8 @@ def openai_completions(model_name=None):
         oai_string = generate_oai_string(30)
 
         _, stream_name, error_msg = event.wait()
-        if error_msg == 'closed':
+        if error_msg:
+            print('OAI failed to start streaming:', error_msg)
             stream_name = None
             return 'Request Timeout', 408
 
@@ -165,7 +164,7 @@ def openai_completions(model_name=None):
             timestamp = int(stream_index.decode('utf-8').split('-')[0])
             data = ujson.loads(item[b'data'])
             if data['error']:
-                print('OAI streaming error:', data['error'])
+                print('OAI streaming encountered error')
                 yield 'data: [DONE]\n\n'
                 return
             elif data['new']:

@@ -29,7 +29,7 @@ class OpenAIRequestHandler(RequestHandler):
         assert not self.used
         if self.offline:
             msg = return_invalid_model_err(self.selected_model)
-            print(msg)
+            print('OAI Offline:', msg)
             return self.handle_error(msg)
 
         if opts.openai_silent_trim:
@@ -106,7 +106,7 @@ class OpenAIRequestHandler(RequestHandler):
         return response, 429
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        print(error_msg)
+        print('OAI Error:', error_msg)
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",

@@ -37,8 +37,6 @@ class RedisPriorityQueue:
         assert priority is not None
         assert selected_model is not None
 
-        event = DataEvent()
-
         # Check if the IP is already in the dictionary and if it has reached the limit
         ip_count = self.get_ip_request_count(item[1])
         _, simultaneous_ip = get_token_ratelimit(item[2])
@@ -47,6 +45,7 @@ class RedisPriorityQueue:
             return None  # reject the request
 
         timestamp = time.time()
+        event = DataEvent()
         self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
         return event
 
@@ -107,7 +106,7 @@ class DataEvent:
     Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol).
     """
 
-    def __init__(self, event_id=None):
+    def __init__(self, event_id: str = None):
         self.event_id = event_id if event_id else str(uuid4())
         self.redis = Redis(host='localhost', port=6379, db=14)
         self.pubsub = self.redis.pubsub()
@@ -118,7 +117,6 @@ class DataEvent:
 
     def wait(self):
         for item in self.pubsub.listen():
-            print(item)
             if item['type'] == 'message':
                 return pickle.loads(item['data'])
 
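
Across this PR the worker now reports back through DataEvent with a (success, payload, error_msg) tuple instead of a raw redis publish. A rough sketch of that handshake, under stated assumptions: it assumes a local Redis reachable on db 14 as in __init__ above, assumes DataEvent subscribes to its event_id channel during construction (the hunks here don't show that line), and the event id / stream name strings are made up:

    import threading
    import time

    from llm_server.routes.queue import DataEvent

    EVENT_ID = 'example-event-id'  # normally generated as a uuid4 by DataEvent()

    def route_side():
        # Route-handler side: block until the worker reports success or an error.
        success, stream_name, error_msg = DataEvent(EVENT_ID).wait()
        if error_msg:
            print('failed to start streaming:', error_msg)  # the routes treat this as a timeout
        else:
            print('consume stream', stream_name)

    t = threading.Thread(target=route_side, daemon=True)
    t.start()
    time.sleep(0.5)  # crude: let the consumer subscribe before publishing

    # Worker side: announce the stream name once the backend accepted the request.
    DataEvent(EVENT_ID).set((True, 'streaming:example-worker', None))
    t.join(timeout=5)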

@@ -44,12 +44,13 @@ class RequestHandler:
         self.backend_url = get_a_cluster_backend(selected_model)
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
 
-        if not self.cluster_backend_info.get('mode'):
-            print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
-        if not self.cluster_backend_info.get('model'):
-            print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
-        if not self.cluster_backend_info.get('model_config'):
-            print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
+        # Debug stuff
+        # if not self.cluster_backend_info.get('mode'):
+        # print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
+        # if not self.cluster_backend_info.get('model'):
+        # print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
+        # if not self.cluster_backend_info.get('model_config'):
+        # print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
 
         if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
             self.offline = True

@@ -1,3 +1,3 @@
 def handle_server_error(e):
-    print(e)
+    print('Internal Error:', e)
     return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

@@ -130,7 +130,8 @@ def do_stream(ws, model_name):
         event_id = event.event_id
 
         _, stream_name, error_msg = event.wait()
-        if error_msg == 'closed':
+        if error_msg:
+            print('Stream failed to start streaming:', error_msg)
             ws.close(reason=1014, message='Request Timeout')
             return
 

@@ -10,6 +10,7 @@ from redis import Redis
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.llm.generator import generator
+from llm_server.logging import create_logger
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
 
 stream_redis = Redis(db=8)
@@ -39,6 +40,7 @@ def get_stream_name(name: str):
 
 
 def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
+    logger = create_logger('inferencer')
     prompt = msg_to_backend['prompt']
     stream_name = get_stream_name(stream_name)
     stream_redis.delete(get_stream_name(stream_name))  # be extra sure
@@ -53,7 +55,7 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
             if not chunk:
                 break
             if event.is_set():
-                print('Client canceled generation')
+                logger.debug('Client canceled generation')
                 response.close()
                 return
 
@@ -70,40 +72,60 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
                    # ????
                    continue
                stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
+    except AttributeError as e:
+        if str(e) == "'bool' object has no attribute 'iter_content'":
+            # We don't care about these errors.
+            logger.debug('failed to stream from backend - no response')
+        else:
+            raise
     except Exception as e:
         stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
-        traceback.print_exc()
+        raise  # We won't handle the exception here.
     finally:
         # Publish final message to Redis stream
         stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
         event.set()  # stop the cancellation checking thread
 
 
+#
 def worker(backend_url):
+    logger = create_logger('inferencer')
     status_redis = RedisCustom('worker_status')
     worker_id = str(uuid4())
     status_redis.setp(str(worker_id), None)
     redis_queue = RedisPriorityQueue(backend_url)
     while True:
+        status_redis.setp(str(worker_id), 'waiting...')
         (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
+        event = DataEvent(event_id)
 
         try:
            backend_info = cluster_config.get_backend(backend_url)
+        except:
+            # This is not a critical error because it usually means that the backend is
+            # offline and this backend is in a state of transition from online to offline.
+            logger.debug(f'got an exception while getting info for backend {backend_url} - ', traceback.format_exc())
+            event.set((False, None, 'exception'))
+            continue
+
        if not backend_info['online']:
-            redis.publish(event_id, 'canceled')
-            return
+            event.set((False, None, 'canceled'))
+            continue
 
        if not selected_model:
            selected_model = backend_info['model']
 
+        logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}")
+
+        try:
            stream_redis.delete(get_stream_name(worker_id))  # clean up any old streams
            increment_ip_count(client_ip, 'processing_ips')
            incr_active_workers(selected_model, backend_url)
-            status_redis.setp(str(worker_id), ('generating', client_ip))
 
            if do_stream:
+                status_redis.setp(str(worker_id), ('streaming', client_ip))
+
                # Return the name of the stream that the slave should connect to.
-                event = DataEvent(event_id)
                event.set((True, get_stream_name(worker_id), None))
 
                msg_to_backend = {
@@ -114,12 +136,12 @@ def worker(backend_url):
                inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
            else:
                # Normal inference (not streaming).
+                status_redis.setp(str(worker_id), ('generating', client_ip))
                success, response, error_msg = generator(request_json_body, backend_url)
-                event = DataEvent(event_id)
                event.set((success, response, error_msg))
        except:
-            traceback.print_exc()
-            redis.publish(event_id, 'canceled')
+            logger.error(traceback.format_exc())
+            event.set((False, None, 'exception'))
        finally:
            decrement_ip_count(client_ip, 'processing_ips')
            decr_active_workers(selected_model, backend_url)
@@ -127,6 +149,7 @@ def worker(backend_url):
 
 
 def start_workers(cluster: dict):
+    logger = create_logger('inferencer')
     i = 0
     for item in cluster:
         for _ in range(item['concurrent_gens']):
@@ -134,4 +157,4 @@ def start_workers(cluster: dict):
             t.daemon = True
             t.start()
             i += 1
-    print(f'Started {i} inference workers.')
+    logger.info(f'Started {i} inference workers.')

@@ -7,6 +7,7 @@ import redis as redis_redis
 
 from llm_server import opts
 from llm_server.llm.openai.moderation import check_moderation_endpoint
+from llm_server.logging import create_logger
 
 redis_moderation = redis_redis.Redis()
 
@@ -18,7 +19,6 @@ def start_moderation_workers(num_workers):
         t.daemon = True
         t.start()
         i += 1
-    print(f'Started {i} moderation workers.')
 
 
 # TODO: don't use UUID tags to identify items. Use native redis
@@ -39,12 +39,13 @@ def get_results(tag, num_tasks):
                flagged_categories.add(item)
        num_results += 1
        if time.time() - start_time > opts.openai_moderation_timeout:
-            print('----> Timed out waiting for result from moderator.')
+            logger.warning('Timed out waiting for result from moderator')
            break
    return list(flagged_categories)
 
 
 def moderation_worker():
+    logger = create_logger('moderator')
     while True:
         result = redis_moderation.blpop(['queue:msgs_to_check'])
         try:
@@ -52,7 +53,7 @@ def moderation_worker():
             _, categories = check_moderation_endpoint(msg)
             redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
         except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
             continue
 

@@ -1,43 +1,24 @@
-import logging
 import time
 import traceback
 
-from llm_server.cluster.backend import get_model_choices, get_running_models
+from llm_server.cluster.backend import get_running_models
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
+from llm_server.logging import create_logger
 from llm_server.routes.queue import priority_queue
-from llm_server.routes.v1.generate_stats import generate_stats
 
-logger = logging.getLogger('console_printer')
-if not logger.handlers:
-    handler = logging.StreamHandler()
-    handler.setLevel(logging.INFO)
-    logger.setLevel(logging.INFO)
-    formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)
 
 
 def console_printer():
+    logger = create_logger('console_printer')
     time.sleep(3)
     while True:
         try:
-            stats = generate_stats()
-            model_choices, default_model = get_model_choices()
+            processing = redis.keys('active_gen_workers:http*')  # backends always start with http
 
             processing_count = 0
-            backend_count = len(stats['backends'])
-            if model_choices and default_model:
-                for model, info in model_choices.items():
-                    processing_count += info['processing']
+            if len(processing):
+                for k in processing:
+                    processing_count += redis.get(k, default=0, dtype=int)
+            backends = [k for k, v in cluster_config.all().items() if v['online']]
 
+            # processing = redis.keys('active_gen_workers:http*') # backends always start with http
+            # processing_count = 0
+            # if len(processing):
+            # for k in processing:
+            # processing_count += redis.get(k, default=0, dtype=int)
+            # backends = [k for k, v in cluster_config.all().items() if v['online']]
             activity = priority_queue.activity()
 
             # Calculate the queue size the same way it's done on the stats.
@@ -47,7 +28,7 @@ def console_printer():
                 queue_size += priority_queue.len(model)
 
             # Active Workers and Processing should read the same. If not, that's an issue.
-            logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {backend_count}')
+            logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
         except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
         time.sleep(10)

@@ -3,6 +3,7 @@ from threading import Thread
 
 from llm_server import opts
 from llm_server.cluster.worker import cluster_worker
+from llm_server.logging import create_logger
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.inferencer import start_workers
 from llm_server.workers.logger import db_logger
@@ -19,36 +20,39 @@ def cache_stats():
 
 
 def start_background():
+    logger = create_logger('threader')
     start_workers(opts.cluster)
 
     t = Thread(target=main_background_thread)
     t.daemon = True
     t.start()
-    print('Started the main background thread.')
+    logger.info('Started the main background thread.')
 
-    start_moderation_workers(opts.cluster_workers * 3)
+    num_moderators = opts.cluster_workers * 3
+    start_moderation_workers(num_moderators)
+    logger.info(f'Started {num_moderators} moderation workers.')
 
     t = Thread(target=cache_stats)
     t.daemon = True
     t.start()
-    print('Started the stats cacher.')
+    logger.info('Started the stats cacher.')
 
     t = Thread(target=recent_prompters_thread)
     t.daemon = True
     t.start()
-    print('Started the recent proompters thread.')
+    logger.info('Started the recent proompters thread.')
 
     t = Thread(target=console_printer)
     t.daemon = True
     t.start()
-    print('Started the console printer.')
+    logger.info('Started the console logger.infoer.')
 
     t = Thread(target=cluster_worker)
     t.daemon = True
     t.start()
-    print('Started the cluster worker.')
+    logger.info('Started the cluster worker.')
 
     t = Thread(target=db_logger)
     t.daemon = True
     t.start()
-    print('Started background logger.')
+    logger.info('Started background logger.')

@@ -9,9 +9,10 @@ simplejson~=3.19.1
 websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
-urllib3~=2.0.4
 flask-sock==0.6.0
 gunicorn==21.2.0
 redis==5.0.1
 ujson==5.8.0
-vllm
+vllm==0.2.1.post1
+gradio~=3.46.1
+coloredlogs~=15.0.1