fix inference workers quitting when a backend is offline, start adding logging, improve tokenizer error handling

Cyberes 2023-10-23 17:24:20 -06:00
parent 3cf73fec9b
commit 0771c2325c
20 changed files with 187 additions and 94 deletions

View File

@ -1,4 +1,5 @@
import argparse
import logging
import os
import sys
import time
@ -10,6 +11,7 @@ from llm_server.cluster.cluster_config import cluster_config
from llm_server.config.load import load_config, parse_backends
from llm_server.custom_redis import redis
from llm_server.database.create import create_db
from llm_server.logging import create_logger, logging_info, init_logging
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.threader import start_background
@ -21,18 +23,26 @@ else:
config_path = Path(script_path, 'config', 'config.yml')
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Daemon microservice.')
parser = argparse.ArgumentParser(description='Daemon microservice.')
parser.add_argument('--no-reset', action='store_true', help="Don't clear the Redis server databases.")
parser.add_argument('-d', '--debug', action='store_true', help='Enable debug logging.')
args = parser.parse_args()
# TODO: have this be set by either the arg or a config value
if args.debug:
logging_info.level = logging.DEBUG
init_logging()
logger = create_logger('daemon')
logger.debug('Debug logging enabled.')
if not args.no_reset:
Redis().flushall()
print('Flushed Redis.')
logger.info('Flushed Redis.')
success, config, msg = load_config(config_path)
if not success:
print('Failed to load config:', msg)
logger.info(f'Failed to load config: {msg}')
sys.exit(1)
create_db()
@ -40,13 +50,13 @@ if __name__ == "__main__":
cluster_config.clear()
cluster_config.load(parse_backends(config))
print('Loading backend stats...')
logger.info('Loading backend stats...')
generate_stats()
start_background()
redis.set('daemon_started', 1)
print('== Daemon Setup Complete ==\n')
logger.info('== Daemon Setup Complete ==')
try:
while True:

View File

@ -1,8 +1,8 @@
import time
from threading import Thread
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.backend import test_backend
from llm_server.cluster.cluster_config import cluster_config
from llm_server.cluster.stores import redis_running_models
@ -26,7 +26,6 @@ def cluster_worker():
def check_backend(n, v, test_prompt):
online, backend_info = test_backend(v['backend_url'], test_prompt=test_prompt)
# purge_backend_from_running_models(n)
if online:
running_model = backend_info['model']
for k, v in backend_info.items():
@ -36,7 +35,4 @@ def check_backend(n, v, test_prompt):
for model in redis_running_models.keys():
redis_running_models.srem(model, n)
# redis_running_models.srem(backend_info['model'], n)
# backend_cycler_store.lrem(backend_info['model'], 1, n)
cluster_config.set_backend_value(n, 'online', online)

View File

@ -1,10 +1,28 @@
import tiktoken
from llm_server.cluster.cluster_config import cluster_config
from llm_server.llm import oobabooga, vllm
from llm_server.logging import create_logger
def fallback_tokenizer(prompt: str):
tokenizer = tiktoken.get_encoding("cl100k_base")
return len(tokenizer.encode(prompt)) + 10
def get_token_count(prompt: str, backend_url: str):
backend_url = cluster_config.validate_backend(backend_url)
backend_mode = cluster_config.get_backend(backend_url)['mode']
if not backend_url:
logger = create_logger('tokenizer')
logger.warning('using fallback tokenizer as there is no valid backend')
return fallback_tokenizer(prompt)
backend_mode = cluster_config.get_backend(backend_url).get('mode')
if not backend_mode:
logger = create_logger('tokenizer')
logger.warning("using fallback tokenizer as the backend isn't initalized")
return fallback_tokenizer(prompt)
if backend_mode == 'vllm':
return vllm.tokenize(prompt, backend_url)
elif backend_mode == 'ooba':
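For reference, a minimal sketch of the fallback path that both new warning branches take. Only the example prompt is an assumption; the cl100k_base encoding and the +10 safety margin come from fallback_tokenizer() above.

import tiktoken

def estimate_tokens(prompt: str) -> int:
    # Mirrors fallback_tokenizer(): a local tiktoken count plus a safety margin,
    # used whenever no healthy, initialized backend is available to tokenize.
    encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(prompt)) + 10

print(estimate_tokens("Hello, world!"))  # 4 real tokens + 10 margin -> 14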

View File

@ -5,6 +5,7 @@ import tiktoken
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.logging import create_logger
def tokenize(prompt: str, backend_url: str) -> int:
@ -16,6 +17,8 @@ def tokenize(prompt: str, backend_url: str) -> int:
return 0
assert isinstance(prompt, str)
logger = create_logger('tokenizer')
# The backend could have died between when the request was
# submitted and now, so let's double check it's still online.
backend_url = cluster_config.validate_backend(backend_url)
@ -33,7 +36,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
j = r.json()
return j['length']
except Exception as e:
print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
logger.debug(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
return len(tokenizer.encode(chunk)) + 10
# Use a ThreadPoolExecutor to send all chunks to the server at once
@ -44,5 +47,5 @@ def tokenize(prompt: str, backend_url: str) -> int:
try:
data = future.result()
except Exception as exc:
print('%r generated an exception: %s' % (chunk, exc))
logger.warning('%r generated an exception: %s' % (chunk, exc))
return sum(future.result() for future in future_to_chunk)

llm_server/logging.py (new file, 52 lines)
View File

@ -0,0 +1,52 @@
import logging
import coloredlogs
from llm_server import opts
class LoggingInfo:
def __init__(self):
self._level = logging.INFO
self._format = opts.LOGGING_FORMAT
@property
def level(self):
return self._level
@level.setter
def level(self, value):
self._level = value
@property
def format(self):
return self._format
@format.setter
def format(self, value):
self._format = value
logging_info = LoggingInfo()
def init_logging():
"""
Set up the parent logger.
:return:
"""
logger = logging.getLogger('llm_server')
logger.setLevel(logging_info.level)
def create_logger(name):
logger = logging.getLogger('llm_server').getChild(name)
logger.setLevel(logging_info.level)
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging_info.level)
formatter = logging.Formatter(logging_info.format)
handler.setFormatter(formatter)
logger.addHandler(handler)
coloredlogs.install(logger=logger, level=logging_info.level)
return logger
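A short usage sketch of this module, pieced together from the daemon.py hunk above; the child logger name is whatever the calling module passes in.

import logging
from llm_server.logging import create_logger, init_logging, logging_info

logging_info.level = logging.DEBUG   # e.g. when daemon.py is started with --debug
init_logging()                       # configure the parent 'llm_server' logger

logger = create_logger('daemon')     # child logger 'llm_server.daemon' with its own handler
logger.debug('Debug logging enabled.')
logger.info('Flushed Redis.')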

View File

@ -1,5 +1,8 @@
# Read-only global variables
# Uppercase variables are read-only globals.
# Lowercase variables are ones that are set on startup and are never changed.
# TODO: rewrite the config system so I don't have to add every single config default here
frontend_api_mode = 'ooba'
@ -39,3 +42,5 @@ openai_moderation_timeout = 5
prioritize_by_size = False
cluster_workers = 0
redis_stream_timeout = 25000
LOGGING_FORMAT = "%(asctime)s: %(levelname)s:%(name)s - %(message)s"
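To see what LOGGING_FORMAT produces, a quick stand-alone check; the timestamp in the comment is illustrative.

import logging

logging.basicConfig(format="%(asctime)s: %(levelname)s:%(name)s - %(message)s", level=logging.INFO)
logging.getLogger('llm_server.daemon').info('Flushed Redis.')
# 2023-10-23 17:24:20,123: INFO:llm_server.daemon - Flushed Redis.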

View File

@ -16,7 +16,7 @@ class OobaRequestHandler(RequestHandler):
def handle_request(self, return_ok: bool = True):
assert not self.used
if self.offline:
print(messages.BACKEND_OFFLINE)
print('This backend is offline:', messages.BACKEND_OFFLINE)
return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()

View File

@ -24,7 +24,7 @@ def handle_error(e):
"auth_subrequest_error"
"""
print(e)
print('OAI returning error:', e)
return jsonify({
"error": {
"message": "Internal server error",

View File

@ -29,9 +29,7 @@ def openai_chat_completions(model_name=None):
else:
handler = OpenAIRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
msg = return_invalid_model_err(model_name)
print(msg)
return handler.handle_error(msg)
return return_invalid_model_err(model_name)
if not request_json_body.get('stream'):
try:
@ -100,7 +98,8 @@ def openai_chat_completions(model_name=None):
# Need to do this before we enter generate() since we want to be able to
# return a 408 if necessary.
_, stream_name, error_msg = event.wait()
if error_msg == 'closed':
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None # set to null so that the Finally ignores it.
return 'Request Timeout', 408
@ -120,7 +119,8 @@ def openai_chat_completions(model_name=None):
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
print('OAI streaming error:', data['error'])
# Not printing the error since we can just check the daemon log.
print('OAI streaming encountered an error')
yield 'data: [DONE]\n\n'
return
elif data['new']:

View File

@ -29,9 +29,7 @@ def openai_completions(model_name=None):
else:
handler = OobaRequestHandler(incoming_request=request, incoming_json=request_json_body, selected_model=model_name)
if handler.offline:
msg = return_invalid_model_err(model_name)
print(msg)
return handler.handle_error(msg)
return return_invalid_model_err(model_name)
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
@ -145,7 +143,8 @@ def openai_completions(model_name=None):
oai_string = generate_oai_string(30)
_, stream_name, error_msg = event.wait()
if error_msg == 'closed':
if error_msg:
print('OAI failed to start streaming:', error_msg)
stream_name = None
return 'Request Timeout', 408
@ -165,7 +164,7 @@ def openai_completions(model_name=None):
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = ujson.loads(item[b'data'])
if data['error']:
print('OAI streaming error:', data['error'])
print('OAI streaming encountered an error')
yield 'data: [DONE]\n\n'
return
elif data['new']:

View File

@ -29,7 +29,7 @@ class OpenAIRequestHandler(RequestHandler):
assert not self.used
if self.offline:
msg = return_invalid_model_err(self.selected_model)
print(msg)
print('OAI Offline:', msg)
return self.handle_error(msg)
if opts.openai_silent_trim:
@ -106,7 +106,7 @@ class OpenAIRequestHandler(RequestHandler):
return response, 429
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
print(error_msg)
print('OAI Error:', error_msg)
return jsonify({
"error": {
"message": "Invalid request, check your parameters and try again.",

View File

@ -37,8 +37,6 @@ class RedisPriorityQueue:
assert priority is not None
assert selected_model is not None
event = DataEvent()
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.get_ip_request_count(item[1])
_, simultaneous_ip = get_token_ratelimit(item[2])
@ -47,6 +45,7 @@ class RedisPriorityQueue:
return None # reject the request
timestamp = time.time()
event = DataEvent()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
return event
@ -107,7 +106,7 @@ class DataEvent:
Class to simplify pub/sub communication between consumers and producers (MASTERS and SLAVES lololololol).
"""
def __init__(self, event_id=None):
def __init__(self, event_id: str = None):
self.event_id = event_id if event_id else str(uuid4())
self.redis = Redis(host='localhost', port=6379, db=14)
self.pubsub = self.redis.pubsub()
@ -118,7 +117,6 @@ class DataEvent:
def wait(self):
for item in self.pubsub.listen():
print(item)
if item['type'] == 'message':
return pickle.loads(item['data'])
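A rough sketch of the handshake DataEvent enables between a waiting route and an inference worker, assuming a local Redis instance and that the constructor subscribes to the event's channel (that part of the class sits outside this hunk). The stream name and the worker stub are illustrative.

import threading
from llm_server.routes.queue import DataEvent

# Route side: create the event up front; its event_id is what travels through
# the Redis queue entry to a worker.
event = DataEvent()

def worker_side(event_id: str) -> None:
    # Worker side: rebuild the event from the id and publish the result tuple,
    # e.g. (success, stream_name_or_response, error_msg) as used in this commit.
    DataEvent(event_id).set((True, 'worker-stream-name', None))

threading.Thread(target=worker_side, args=(event.event_id,), daemon=True).start()
success, payload, error_msg = event.wait()   # blocks until the worker publishes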

View File

@ -44,12 +44,13 @@ class RequestHandler:
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
if not self.cluster_backend_info.get('mode'):
print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('model'):
print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('model_config'):
print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
# Debug stuff
# if not self.cluster_backend_info.get('mode'):
# print('keyerror: mode -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model'):
# print('keyerror: model -', selected_model, self.backend_url, self.cluster_backend_info)
# if not self.cluster_backend_info.get('model_config'):
# print('keyerror: model_config -', selected_model, self.backend_url, self.cluster_backend_info)
if not self.cluster_backend_info.get('mode') or not self.cluster_backend_info.get('model') or not self.cluster_backend_info.get('model_config'):
self.offline = True

View File

@ -1,3 +1,3 @@
def handle_server_error(e):
print(e)
print('Internal Error:', e)
return {'error': True, 'code': 500, 'message': 'Internal Server Error :('}, 500

View File

@ -130,7 +130,8 @@ def do_stream(ws, model_name):
event_id = event.event_id
_, stream_name, error_msg = event.wait()
if error_msg == 'closed':
if error_msg:
print('Stream failed to start streaming:', error_msg)
ws.close(reason=1014, message='Request Timeout')
return

View File

@ -10,6 +10,7 @@ from redis import Redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.logging import create_logger
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
stream_redis = Redis(db=8)
@ -39,6 +40,7 @@ def get_stream_name(name: str):
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
logger = create_logger('inferencer')
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
@ -53,7 +55,7 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
if not chunk:
break
if event.is_set():
print('Client canceled generation')
logger.debug('Client canceled generation')
response.close()
return
@ -70,40 +72,60 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
# ????
continue
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
except AttributeError as e:
if str(e) == "'bool' object has no attribute 'iter_content'":
# We don't care about these errors.
logger.debug('failed to stream from backend - no response')
else:
raise
except Exception as e:
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
traceback.print_exc()
raise # We won't handle the exception here.
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
event.set() # stop the cancellation checking thread
#
def worker(backend_url):
logger = create_logger('inferencer')
status_redis = RedisCustom('worker_status')
worker_id = str(uuid4())
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
status_redis.setp(str(worker_id), 'waiting...')
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
event = DataEvent(event_id)
try:
backend_info = cluster_config.get_backend(backend_url)
except:
# This is not a critical error because it usually means that the backend is
# offline and this backend is in a state of transition from online to offline.
logger.debug(f'got an exception while getting info for backend {backend_url}: {traceback.format_exc()}')
event.set((False, None, 'exception'))
continue
if not backend_info['online']:
redis.publish(event_id, 'canceled')
return
event.set((False, None, 'canceled'))
continue
if not selected_model:
selected_model = backend_info['model']
logger.debug(f"Starting using {backend_url} and {selected_model}. Online: {backend_info['online']}")
try:
stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), ('generating', client_ip))
if do_stream:
status_redis.setp(str(worker_id), ('streaming', client_ip))
# Return the name of the stream that the slave should connect to.
event = DataEvent(event_id)
event.set((True, get_stream_name(worker_id), None))
msg_to_backend = {
@ -114,12 +136,12 @@ def worker(backend_url):
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
else:
# Normal inference (not streaming).
status_redis.setp(str(worker_id), ('generating', client_ip))
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)
event.set((success, response, error_msg))
except:
traceback.print_exc()
redis.publish(event_id, 'canceled')
logger.error(traceback.format_exc())
event.set((False, None, 'exception'))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
@ -127,6 +149,7 @@ def worker(backend_url):
def start_workers(cluster: dict):
logger = create_logger('inferencer')
i = 0
for item in cluster:
for _ in range(item['concurrent_gens']):
@ -134,4 +157,4 @@ def start_workers(cluster: dict):
t.daemon = True
t.start()
i += 1
print(f'Started {i} inference workers.')
logger.info(f'Started {i} inference workers.')
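The heart of the commit's headline fix is the switch from return to continue in worker() when a backend is offline. Below is a self-contained sketch of that control flow, with the queue, the health check, and the reply call stubbed out; none of these stand-ins are the real llm_server APIs.

import queue

job_queue = queue.Queue()

def backend_is_online() -> bool:
    # Stand-in for cluster_config.get_backend(backend_url)['online'].
    return False

def answer(event_id: str, result: tuple) -> None:
    # Stand-in for DataEvent(event_id).set(result).
    print(event_id, '->', result)

def worker() -> None:
    while True:
        job = job_queue.get()
        if job is None:                 # sentinel so this demo can finish
            return
        if not backend_is_online():
            answer(job['event_id'], (False, None, 'canceled'))
            continue                    # the fix: stay alive and take the next job
        answer(job['event_id'], (True, 'response', None))

job_queue.put({'event_id': 'abc123'})
job_queue.put(None)
worker()   # prints: abc123 -> (False, None, 'canceled')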

View File

@ -7,6 +7,7 @@ import redis as redis_redis
from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger
redis_moderation = redis_redis.Redis()
@ -18,7 +19,6 @@ def start_moderation_workers(num_workers):
t.daemon = True
t.start()
i += 1
print(f'Started {i} moderation workers.')
# TODO: don't use UUID tags to identify items. Use native redis
@ -39,12 +39,13 @@ def get_results(tag, num_tasks):
flagged_categories.add(item)
num_results += 1
if time.time() - start_time > opts.openai_moderation_timeout:
print('----> Timed out waiting for result from moderator.')
logger.warning('Timed out waiting for result from moderator')
break
return list(flagged_categories)
def moderation_worker():
logger = create_logger('moderator')
while True:
result = redis_moderation.blpop(['queue:msgs_to_check'])
try:
@ -52,7 +53,7 @@ def moderation_worker():
_, categories = check_moderation_endpoint(msg)
redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
except:
traceback.print_exc()
logger.error(traceback.format_exc())
continue

View File

@ -1,43 +1,24 @@
import logging
import time
import traceback
from llm_server.cluster.backend import get_model_choices, get_running_models
from llm_server.cluster.backend import get_running_models
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.logging import create_logger
from llm_server.routes.queue import priority_queue
from llm_server.routes.v1.generate_stats import generate_stats
logger = logging.getLogger('console_printer')
if not logger.handlers:
handler = logging.StreamHandler()
handler.setLevel(logging.INFO)
logger.setLevel(logging.INFO)
formatter = logging.Formatter("%(asctime)s: %(levelname)s:%(name)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)
def console_printer():
logger = create_logger('console_printer')
time.sleep(3)
while True:
try:
stats = generate_stats()
model_choices, default_model = get_model_choices()
processing = redis.keys('active_gen_workers:http*') # backends always start with http
processing_count = 0
backend_count = len(stats['backends'])
if model_choices and default_model:
for model, info in model_choices.items():
processing_count += info['processing']
# processing = redis.keys('active_gen_workers:http*') # backends always start with http
# processing_count = 0
# if len(processing):
# for k in processing:
# processing_count += redis.get(k, default=0, dtype=int)
# backends = [k for k, v in cluster_config.all().items() if v['online']]
if len(processing):
for k in processing:
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
activity = priority_queue.activity()
# Calculate the queue size the same way it's done on the stats.
@ -47,7 +28,7 @@ def console_printer():
queue_size += priority_queue.len(model)
# Active Workers and Processing should read the same. If not, that's an issue.
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {backend_count}')
logger.info(f'Active Workers: {len([i for i in activity if (i[1] and i[1] != "waiting...")])} | Processing: {processing_count} | Queued: {queue_size} | Backends Online: {len(backends)}')
except:
traceback.print_exc()
logger.error(traceback.format_exc())
time.sleep(10)

View File

@ -3,6 +3,7 @@ from threading import Thread
from llm_server import opts
from llm_server.cluster.worker import cluster_worker
from llm_server.logging import create_logger
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers
from llm_server.workers.logger import db_logger
@ -19,36 +20,39 @@ def cache_stats():
def start_background():
logger = create_logger('threader')
start_workers(opts.cluster)
t = Thread(target=main_background_thread)
t.daemon = True
t.start()
print('Started the main background thread.')
logger.info('Started the main background thread.')
start_moderation_workers(opts.cluster_workers * 3)
num_moderators = opts.cluster_workers * 3
start_moderation_workers(num_moderators)
logger.info(f'Started {num_moderators} moderation workers.')
t = Thread(target=cache_stats)
t.daemon = True
t.start()
print('Started the stats cacher.')
logger.info('Started the stats cacher.')
t = Thread(target=recent_prompters_thread)
t.daemon = True
t.start()
print('Started the recent proompters thread.')
logger.info('Started the recent proompters thread.')
t = Thread(target=console_printer)
t.daemon = True
t.start()
print('Started the console printer.')
logger.info('Started the console printer.')
t = Thread(target=cluster_worker)
t.daemon = True
t.start()
print('Started the cluster worker.')
logger.info('Started the cluster worker.')
t = Thread(target=db_logger)
t.daemon = True
t.start()
print('Started background logger.')
logger.info('Started background logger.')

View File

@ -9,9 +9,10 @@ simplejson~=3.19.1
websockets~=11.0.3
basicauth~=1.0.0
openai~=0.28.0
urllib3~=2.0.4
flask-sock==0.6.0
gunicorn==21.2.0
redis==5.0.1
ujson==5.8.0
vllm
vllm==0.2.1.post1
gradio~=3.46.1
coloredlogs~=15.0.1