Merge cluster to master #3

Merged
cyberes merged 163 commits from cluster into master 2023-10-27 19:19:22 -06:00
20 changed files with 187 additions and 94 deletions
Showing only changes of commit 0771c2325c

View File

@@ -1,4 +1,5 @@
 import argparse
+import logging
 import os
 import sys
 import time
@@ -10,6 +11,7 @@ from llm_server.cluster.cluster_config import cluster_config
 from llm_server.config.load import load_config, parse_backends
 from llm_server.custom_redis import redis
 from llm_server.database.create import create_db
+from llm_server.logging import create_logger, logging_info, init_logging
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.threader import start_background
@@ -21,18 +23,26 @@ else:
     config_path = Path(script_path, 'config', 'config.yml')

 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description='Daemon microservice.')
+    parser = argparse.ArgumentParser(description='Daemon microservice.')
     parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
+    parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
     args = parser.parse_args()

+    # TODO: have this be set by either the arg or a config value
+    if args.debug:
+        logging_info.level = logging.DEBUG
+
+    init_logging()
+    logger = create_logger('daemon')
+    logger.debug('Debug logging enabled.')
+
     if not args.no_reset:
         Redis().flushall()
-        print('Flushed Redis.')
+        logger.info('Flushed Redis.')

     success, config, msg = load_config(config_path)
     if not success:
-        print('Failed to load config:', msg)
+        logger.info(f'Failed to load config: {msg}')
         sys.exit(1)
     create_db()
@@ -40,13 +50,13 @@ if __name__ == "__main__":
     cluster_config.clear()
     cluster_config.load(parse_backends(config))

-    print('Loading backend stats...')
+    logger.info('Loading backend stats...')
     generate_stats()

     start_background()

     redis.set('daemon_started', 1)
-    print('== Daemon Setup Complete ==\n')
+    logger.info('== Daemon Setup Complete ==')

     try:
         while True:

View File

@@ -1,8 +1,8 @@
 import time
 from threading import Thread

-from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.backend import test_backend
+from llm_server.cluster.cluster_config import cluster_config
 from llm_server.cluster.stores import redis_running_models
@@ -26,7 +26,6 @@ def cluster_worker():
 def check_backend(n, v, test_prompt):
     online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
-    # purge_backend_from_running_models(n)
     if online:
         running_model = backend_info['model']
         for k, v in backend_info.items():
@@ -36,7 +35,4 @@ def check_backend(n, v, test_prompt):
         for model in redis_running_models.keys():
             redis_running_models.srem(model, n)
-        # redis_running_models.srem(backend_info['model'], n)
-        # backend_cycler_store.lrem(backend_info['model'], 1, n)
     cluster_config.set_backend_value(n, 'online', online)

View File

@@ -1,10 +1,28 @@
+import tiktoken
+
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.llm import oobabooga, vllm
+from llm_server.logging import create_logger
+
+
+def fallback_tokenizer(prompt: str):
+    tokenizer = tiktoken.get_encoding("cl100k_base")
+    return len(tokenizer.encode(prompt)) + 10


 def get_token_count(prompt: str, backend_url: str):
     backend_url = cluster_config.validate_backend(backend_url)
-    backend_mode = cluster_config.get_backend(backend_url)['mode']
+    if not backend_url:
+        logger = create_logger('tokenizer')
+        logger.warning('using fallback tokenizer as there is no valid backend')
+        return fallback_tokenizer(prompt)
+
+    backend_mode = cluster_config.get_backend(backend_url).get('mode')
+    if not backend_mode:
+        logger = create_logger('tokenizer')
+        logger.warning("using fallback tokenizer as the backend isn't initalized")
+        return fallback_tokenizer(prompt)
+
     if backend_mode == 'vllm':
         return vllm.tokenize(prompt, backend_url)
     elif backend_mode == 'ooba':
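
For reference, a minimal standalone sketch of the fallback path added above. The cl100k_base encoding is a generic tiktoken tokenizer and the +10 padding mirrors the new fallback_tokenizer() helper; the function name here is just for illustration.

import tiktoken

def approx_token_count(prompt: str) -> int:
    # Generic BPE tokenizer; pad by 10 tokens as a safety margin,
    # matching the fallback_tokenizer() helper in this commit.
    enc = tiktoken.get_encoding("cl100k_base")
    return len(enc.encode(prompt)) + 10

print(approx_token_count("Hello, world!"))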

View File

@@ -5,6 +5,7 @@ import tiktoken

 from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
+from llm_server.logging import create_logger


 def tokenize(prompt: str, backend_url: str) -> int:
@@ -16,6 +17,8 @@ def tokenize(prompt: str, backend_url: str) -> int:
         return 0
     assert isinstance(prompt, str)

+    logger = create_logger('tokenizer')
+
     # The backend could have died between when the request was
     # submitted and now, so let's double check it's still online.
     backend_url = cluster_config.validate_backend(backend_url)
@@ -33,7 +36,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
             j = r.json()
             return j['length']
         except Exception as e:
-            print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
+            logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
             return len(tokenizer.encode(chunk)) + 10

     # Use a ThreadPoolExecutor to send all chunks to the server at once
@@ -44,5 +47,5 @@ def tokenize(prompt: str, backend_url: str) -> int:
         try:
             data = future.result()
         except Exception as exc:
-            print('%r generated an exception: %s' % (chunk, exc))
+            logger.warning('%r generated an exception: %s' % (chunk, exc))
     return sum(future.result() for future in future_to_chunk)
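
For context, a rough sketch of the chunk-and-sum pattern this file uses. The real code sends each chunk to the VLLM backend's tokenizer and falls back to tiktoken on failure; here a local tiktoken count stands in for the network call, and the chunk size, worker count, and function names are assumptions.

import concurrent.futures

import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def count_chunk(chunk: str) -> int:
    # Stand-in for the real per-chunk tokenize request to the backend.
    return len(enc.encode(chunk))

def tokenize_chunked(prompt: str, chunk_size: int = 300) -> int:
    # Split the prompt, count each chunk concurrently, then sum the results.
    chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(count_chunk, c) for c in chunks]
        return sum(f.result() for f in futures)

print(tokenize_chunked("some long prompt " * 100))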

llm_server/logging.py (new file, 52 lines added)
View File

@@ -0,0 +1,52 @@
+import logging
+
+import coloredlogs
+
+from llm_server import opts
+
+
+class LoggingInfo:
+    def __init__(self):
+        self._level = logging.INFO
+        self._format = opts.LOGGING_FORMAT
+
+    @property
+    def level(self):
+        return self._level
+
+    @level.setter
+    def level(self, value):
+        self._level = value
+
+    @property
+    def format(self):
+        return self._format
+
+    @format.setter
+    def format(self, value):
+        self._format = value
+
+
+logging_info = LoggingInfo()
+
+
+def init_logging():
+    """
+    Set up the parent logger.
+    :return:
+    """
+    logger = logging.getLogger('llm_server')
+    logger.setLevel(logging_info.level)
+
+
+def create_logger(name):
+    logger = logging.getLogger('llm_server').getChild(name)
+    logger.setLevel(logging_info.level)
+    if not logger.handlers:
+        handler = logging.StreamHandler()
+        handler.setLevel(logging_info.level)
+        formatter = logging.Formatter(logging_info.format)
+        handler.setFormatter(formatter)
+        logger.addHandler(handler)
+        coloredlogs.install(logger=logger, level=logging_info.level)
+    return logger
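
A quick usage sketch of the new module, following what daemon.py does above. The 'daemon' child name is just an example; opts.LOGGING_FORMAT is defined in the opts.py change later in this diff.

import logging

from llm_server.logging import create_logger, init_logging, logging_info

# Raise verbosity before any loggers are created, as daemon.py does with --debug.
logging_info.level = logging.DEBUG

init_logging()                    # configures the parent 'llm_server' logger
logger = create_logger('daemon')  # child logger 'llm_server.daemon' with coloredlogs attached
logger.info('Flushed Redis.')
logger.debug('Debug logging enabled.')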

View File

@@ -1,5 +1,8 @@
 # Read-only global variables
+# Uppercase variables are read-only globals.
+# Lowercase variables are ones that are set on startup and are never changed.
+
 # TODO: rewrite the config system so I don't have to add every single config default here

 frontend_api_mode = 'ooba'
@@ -39,3 +42,5 @@ openai_moderation_timeout = 5
 prioritize_by_size = False
 cluster_workers = 0
 redis_stream_timeout = 25000
+
+LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"

View File

@@ -16,7 +16,7 @@ class OobaRequestHandler(RequestHandler):
     def handle_request(self, return_ok: bool = True):
         assert not self.used
         if self.offline:
-            print(messages.BACKEND_OFFLINE)
+            print('This backend is offline:', messages.BACKEND_OFFLINE)
             return self.handle_error(messages.BACKEND_OFFLINE)

         request_valid, invalid_response = self.validate_request()

View File

@@ -24,7 +24,7 @@ def handle_error(e):
        "auth_subrequest_error"
    """
-    print(e)
+    print('OAI returning error:', e)
     return jsonify({
         "error": {
             "message": "Internal server error",

View File

@@ -29,9 +29,7 @@ def openai_chat_completions(model_name=None):
     else:
         handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
     if handler.offline:
-        msg = return_invalid_model_err(model_name)
-        print(msg)
-        return handler.handle_error(msg)
+        return return_invalid_model_err(model_name)

     if not request_json_body.get('stream'):
         try:
@@ -100,7 +98,8 @@ def openai_chat_completions(model_name=None):
                 # Need to do this before we enter generate() since we want to be able to
                 # return a 408 if necessary.
                 _, stream_name, error_msg = event.wait()
-                if error_msg == 'closed':
+                if error_msg:
+                    print('OAI failed to start streaming:', error_msg)
                     stream_name = None  # set to null so that the Finally ignores it.
                     return 'Request Timeout', 408
@@ -120,7 +119,8 @@ def openai_chat_completions(model_name=None):
                         timestamp = int(stream_index.decode('utf-8').split('-')[0])
                         data = ujson.loads(item[b'data'])
                         if data['error']:
-                            print('OAI streaming error:', data['error'])
+                            # Not printing error since we can just check the daemon log.
+                            print('OAI streaming encountered error')
                             yield 'data: [DONE]\n\n'
                             return
                         elif data['new']:

View File

@@ -29,9 +29,7 @@ def openai_completions(model_name=None):
     else:
         handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
     if handler.offline:
-        msg = return_invalid_model_err(model_name)
-        print(msg)
-        return handler.handle_error(msg)
+        return return_invalid_model_err(model_name)

     if handler.cluster_backend_info['mode'] != 'vllm':
         # TODO: implement other backends
@@ -145,7 +143,8 @@ def openai_completions(model_name=None):
                 oai_string = generate_oai_string(30)

                 _, stream_name, error_msg = event.wait()
-                if error_msg == 'closed':
+                if error_msg:
+                    print('OAI failed to start streaming:', error_msg)
                     stream_name = None
                     return 'Request Timeout', 408
@@ -165,7 +164,7 @@ def openai_completions(model_name=None):
                         timestamp = int(stream_index.decode('utf-8').split('-')[0])
                         data = ujson.loads(item[b'data'])
                         if data['error']:
-                            print('OAI streaming error:', data['error'])
+                            print('OAI streaming encountered error')
                             yield 'data: [DONE]\n\n'
                             return
                         elif data['new']:

View File

@@ -29,7 +29,7 @@ class OpenAIRequestHandler(RequestHandler):
         assert not self.used
         if self.offline:
             msg = return_invalid_model_err(self.selected_model)
-            print(msg)
+            print('OAI Offline:', msg)
             return self.handle_error(msg)

         if opts.openai_silent_trim:
@@ -106,7 +106,7 @@ class OpenAIRequestHandler(RequestHandler):
         return response, 429

     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
-        print(error_msg)
+        print('OAI Error:', error_msg)
         return jsonify({
             "error": {
                 "message": "Invalid request, check your parameters and try again.",

View File

@@ -37,8 +37,6 @@ class RedisPriorityQueue:
         assert priority is not None
         assert selected_model is not None

-        event = DataEvent()
-
         # Check if the IP is already in the dictionary and if it has reached the limit
         ip_count = self.get_ip_request_count(item[1])
         _, simultaneous_ip = get_token_ratelimit(item[2])
@@ -47,6 +45,7 @@ class RedisPriorityQueue:
             return None  # reject the request

         timestamp = time.time()
+        event = DataEvent()
         self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
         return event
@@ -107,7 +106,7 @@ class DataEvent:
     Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol).
     """

-    def __init__(self, event_id=None):
+    def __init__(self, event_id: str = None):
         self.event_id = event_id if event_id else str(uuid4())
         self.redis = Redis(host='localhost', port=6379, db=14)
         self.pubsub = self.redis.pubsub()
@@ -118,7 +117,6 @@ class DataEvent:
     def wait(self):
         for item in self.pubsub.listen():
-            print(item)
             if item['type'] == 'message':
                 return pickle.loads(item['data'])
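
For context, a minimal sketch of the pub/sub round trip DataEvent wraps, using plain redis-py. The set() side isn't shown in this hunk; the assumption that it publishes a pickled payload on the event_id channel is inferred from wait() above and from how the workers call event.set() elsewhere in this diff.

import pickle

from redis import Redis

r = Redis(host='localhost', port=6379, db=14)
event_id = 'example-event-id'  # normally a uuid4 string

# Consumer process (what DataEvent.wait() does): block until a message arrives.
def wait_for_result():
    pubsub = r.pubsub()
    pubsub.subscribe(event_id)
    for item in pubsub.listen():
        if item['type'] == 'message':
            return pickle.loads(item['data'])

# Producer process (assumed shape of DataEvent.set()): publish a pickled tuple,
# e.g. (success, stream_name, error_msg) as the inference workers do.
def publish_result():
    r.publish(event_id, pickle.dumps((True, 'stream-name', None)))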

View File

@@ -44,12 +44,13 @@ class RequestHandler:
         self.backend_url = get_a_cluster_backend(selected_model)
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)

-        if not self.cluster_backend_info.get('mode'):
-            print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
-        if not self.cluster_backend_info.get('model'):
-            print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
-        if not self.cluster_backend_info.get('model_config'):
-            print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
+        # Debug stuff
+        # if not self.cluster_backend_info.get('mode'):
+        #     print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
+        # if not self.cluster_backend_info.get('model'):
+        #     print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
+        # if not self.cluster_backend_info.get('model_config'):
+        #     print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)

         if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
             self.offline = True

View File

@@ -1,3 +1,3 @@
 def handle_server_error(e):
-    print(e)
+    print('Internal Error:', e)
     return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@@ -130,7 +130,8 @@ def do_stream(ws, model_name):
             event_id = event.event_id

             _, stream_name, error_msg = event.wait()
-            if error_msg == 'closed':
+            if error_msg:
+                print('Stream failed to start streaming:', error_msg)
                 ws.close(reason=1014, message='Request Timeout')
                 return

View File

@@ -10,6 +10,7 @@ from redis import Redis
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.llm.generator import generator
+from llm_server.logging import create_logger
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count

 stream_redis = Redis(db=8)
@@ -39,6 +40,7 @@ def get_stream_name(name: str):

 def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
+    logger = create_logger('inferencer')
     prompt = msg_to_backend['prompt']
     stream_name = get_stream_name(stream_name)
     stream_redis.delete(get_stream_name(stream_name))  # be extra sure
@@ -53,7 +55,7 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
                 if not chunk:
                     break
                 if event.is_set():
-                    print('Client canceled generation')
+                    logger.debug('Client canceled generation')
                     response.close()
                     return
@@ -70,40 +72,60 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
                     # ????
                     continue
                 stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
+    except AttributeError as e:
+        if str(e) == "'bool' object has no attribute 'iter_content'":
+            # We don't care about these errors.
+            logger.debug('failed to stream from backend - no response')
+        else:
+            raise
     except Exception as e:
         stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
-        traceback.print_exc()
+        raise  # We won't handle the exception here.
     finally:
         # Publish final message to Redis stream
         stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
         event.set()  # stop the cancellation checking thread

-#

 def worker(backend_url):
+    logger = create_logger('inferencer')
     status_redis = RedisCustom('worker_status')
     worker_id = str(uuid4())
     status_redis.setp(str(worker_id), None)
     redis_queue = RedisPriorityQueue(backend_url)
     while True:
+        status_redis.setp(str(worker_id), 'waiting...')
         (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
+        event = DataEvent(event_id)

         try:
             backend_info = cluster_config.get_backend(backend_url)
+        except:
+            # This is not a critical error because it usually means that the backend is
+            # offline and this backend is in a state of transition from online to offline.
+            logger.debug(f'got an exception while getting info for backend {backend_url} - ', traceback.format_exc())
+            event.set((False, None, 'exception'))
+            continue

         if not backend_info['online']:
-            redis.publish(event_id, 'canceled')
-            return
+            event.set((False, None, 'canceled'))
+            continue

         if not selected_model:
             selected_model = backend_info['model']

+        logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}")
+
+        try:
             stream_redis.delete(get_stream_name(worker_id))  # clean up any old streams
             increment_ip_count(client_ip, 'processing_ips')
             incr_active_workers(selected_model, backend_url)
-            status_redis.setp(str(worker_id), ('generating', client_ip))

             if do_stream:
+                status_redis.setp(str(worker_id), ('streaming', client_ip))
                 # Return the name of the stream that the slave should connect to.
-                event = DataEvent(event_id)
                 event.set((True, get_stream_name(worker_id), None))

                 msg_to_backend = {
@@ -114,12 +136,12 @@ def worker(backend_url):
                 inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
             else:
                 # Normal inference (not streaming).
+                status_redis.setp(str(worker_id), ('generating', client_ip))
                 success, response, error_msg = generator(request_json_body, backend_url)
-                event = DataEvent(event_id)
                 event.set((success, response, error_msg))
         except:
-            traceback.print_exc()
-            redis.publish(event_id, 'canceled')
+            logger.error(traceback.format_exc())
+            event.set((False, None, 'exception'))
         finally:
             decrement_ip_count(client_ip, 'processing_ips')
             decr_active_workers(selected_model, backend_url)
@@ -127,6 +149,7 @@ def worker(backend_url):

 def start_workers(cluster: dict):
+    logger = create_logger('inferencer')
     i = 0
     for item in cluster:
         for _ in range(item['concurrent_gens']):
@@ -134,4 +157,4 @@ def start_workers(cluster: dict):
             t.daemon = True
             t.start()
             i += 1
-    print(f'Started {i} inference workers.')
+    logger.info(f'Started {i} inference workers.')
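
For context, a rough sketch of the consumer side of the stream that inference_do_stream() feeds with xadd(). The real readers live in the OpenAI and websocket routes (they decode entries with ujson and honor opts.redis_stream_timeout); the stream name handling, 25-second timeout, and use of the stdlib json module here are simplifications.

import json

from redis import Redis

stream_redis = Redis(db=8)

def read_stream(stream_name: str, timeout_ms: int = 25000):
    last_id = '0-0'
    while True:
        # Block until the worker xadd()s another chunk, or time out.
        response = stream_redis.xread({stream_name: last_id}, block=timeout_ms, count=1)
        if not response:
            return  # timed out
        for _, entries in response:
            for entry_id, item in entries:
                last_id = entry_id
                data = json.loads(item[b'data'])
                if data['error'] or data['completed']:
                    return
                yield data['new']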

View File

@@ -7,6 +7,7 @@ import redis as redis_redis

 from llm_server import opts
 from llm_server.llm.openai.moderation import check_moderation_endpoint
+from llm_server.logging import create_logger

 redis_moderation = redis_redis.Redis()
@@ -18,7 +19,6 @@ def start_moderation_workers(num_workers):
         t.daemon = True
         t.start()
         i += 1
-    print(f'Started {i} moderation workers.')


 # TODO: don't use UUID tags to identify items. Use native redis
@@ -39,12 +39,13 @@ def get_results(tag, num_tasks):
                 flagged_categories.add(item)
             num_results += 1
         if time.time() - start_time > opts.openai_moderation_timeout:
-            print('----> Timed out waiting for result from moderator.')
+            logger.warning('Timed out waiting for result from moderator')
             break
     return list(flagged_categories)


 def moderation_worker():
+    logger = create_logger('moderator')
     while True:
         result = redis_moderation.blpop(['queue:msgs_to_check'])
         try:
@@ -52,7 +53,7 @@ def moderation_worker():
             _, categories = check_moderation_endpoint(msg)
             redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
         except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
             continue

View File

@@ -1,43 +1,24 @@
-import logging
 import time
 import traceback

-from llm_server.cluster.backend import get_model_choices, get_running_models
+from llm_server.cluster.backend import get_running_models
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import redis
+from llm_server.logging import create_logger
 from llm_server.routes.queue import priority_queue
-from llm_server.routes.v1.generate_stats import generate_stats
-
-logger = logging.getLogger('console_printer')
-if not logger.handlers:
-    handler = logging.StreamHandler()
-    handler.setLevel(logging.INFO)
-    logger.setLevel(logging.INFO)
-    formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
-    handler.setFormatter(formatter)
-    logger.addHandler(handler)


 def console_printer():
+    logger = create_logger('console_printer')
     time.sleep(3)
     while True:
         try:
-            stats = generate_stats()
-            model_choices, default_model = get_model_choices()
-
+            processing = redis.keys('active_gen_workers:http*')  # backends always start with http
             processing_count = 0
-            backend_count = len(stats['backends'])
-
-            if model_choices and default_model:
-                for model, info in model_choices.items():
-                    processing_count += info['processing']
-
-            # processing = redis.keys('active_gen_workers:http*')  # backends always start with http
-            # processing_count = 0
-            # if len(processing):
-            #     for k in processing:
-            #         processing_count += redis.get(k, default=0, dtype=int)
-            # backends = [k for k, v in cluster_config.all().items() if v['online']]
+            if len(processing):
+                for k in processing:
+                    processing_count += redis.get(k, default=0, dtype=int)
+            backends = [k for k, v in cluster_config.all().items() if v['online']]

             activity = priority_queue.activity()

             # Calculate the queue size the same way it's done on the stats.
@@ -47,7 +28,7 @@ def console_printer():
                 queue_size += priority_queue.len(model)

             # Active Workers and Processing should read the same. If not, that's an issue.
-            logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {backend_count}')
+            logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
         except:
-            traceback.print_exc()
+            logger.error(traceback.format_exc())
         time.sleep(10)
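
A rough plain redis-py equivalent of the new worker-count logic, for reference. The project actually goes through its RedisCustom wrapper (which supports get(key, default=0, dtype=int)); the scan_iter-based helper below is only an approximation.

from redis import Redis

r = Redis()

def count_processing_workers() -> int:
    # Each backend has an 'active_gen_workers:http...' counter; sum them all.
    total = 0
    for key in r.scan_iter(match='active_gen_workers:http*'):
        value = r.get(key)
        total += int(value) if value is not None else 0
    return total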

View File

@@ -3,6 +3,7 @@ from threading import Thread

 from llm_server import opts
 from llm_server.cluster.worker import cluster_worker
+from llm_server.logging import create_logger
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.inferencer import start_workers
 from llm_server.workers.logger import db_logger
@@ -19,36 +20,39 @@ def cache_stats():

 def start_background():
+    logger = create_logger('threader')
     start_workers(opts.cluster)

     t = Thread(target=main_background_thread)
     t.daemon = True
     t.start()
-    print('Started the main background thread.')
+    logger.info('Started the main background thread.')

-    start_moderation_workers(opts.cluster_workers * 3)
+    num_moderators = opts.cluster_workers * 3
+    start_moderation_workers(num_moderators)
+    logger.info(f'Started {num_moderators} moderation workers.')

     t = Thread(target=cache_stats)
     t.daemon = True
     t.start()
-    print('Started the stats cacher.')
+    logger.info('Started the stats cacher.')

     t = Thread(target=recent_prompters_thread)
     t.daemon = True
     t.start()
-    print('Started the recent proompters thread.')
+    logger.info('Started the recent proompters thread.')

     t = Thread(target=console_printer)
     t.daemon = True
     t.start()
-    print('Started the console printer.')
+    logger.info('Started the console logger.infoer.')

     t = Thread(target=cluster_worker)
     t.daemon = True
     t.start()
-    print('Started the cluster worker.')
+    logger.info('Started the cluster worker.')

     t = Thread(target=db_logger)
     t.daemon = True
     t.start()
-    print('Started background logger.')
+    logger.info('Started background logger.')

View File

@@ -9,9 +9,10 @@ simplejson~=3.19.1
 websockets~=11.0.3
 basicauth~=1.0.0
 openai~=0.28.0
-urllib3~=2.0.4
 flask-sock==0.6.0
 gunicorn==21.2.0
 redis==5.0.1
 ujson==5.8.0
-vllm
+vllm==0.2.1.post1
+gradio~=3.46.1
+coloredlogs~=15.0.1