From 31ab4188f1e8bb1329fc40ae4fd9e0a0b132eeac Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sun, 15 Oct 2023 20:45:01 -0600
Subject: [PATCH] fix issues with queue and streaming

---
 llm_server/custom_redis.py                   |  4 +-
 llm_server/llm/oobabooga/ooba_backend.py     | 77 +-------------------
 llm_server/llm/vllm/tokenize.py              |  2 +-
 llm_server/routes/ooba_request_handler.py    |  2 +-
 llm_server/routes/openai/chat_completions.py | 14 +++-
 llm_server/routes/queue.py                   | 45 ++++++++++--
 llm_server/workers/inferencer.py             | 61 ++++++++++++++--
 llm_server/workers/mainer.py                 |  6 ++
 llm_server/workers/printer.py                |  6 +-
 9 files changed, 119 insertions(+), 98 deletions(-)

diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py
index 485cb58..aacaec0 100644
--- a/llm_server/custom_redis.py
+++ b/llm_server/custom_redis.py
@@ -200,10 +200,10 @@ class RedisCustom(Redis):
         return json.loads(r.decode("utf-8"))
 
     def setp(self, name, value):
-        self.redis.set(name, pickle.dumps(value))
+        self.redis.set(self._key(name), pickle.dumps(value))
 
     def getp(self, name: str):
-        r = self.redis.get(name)
+        r = self.redis.get(self._key(name))
         if r:
             return pickle.loads(r)
         return r
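The setp/getp fix above just routes pickled values through the same _key() prefixing the wrapper already applies elsewhere, so pickled entries land in the wrapper's namespace instead of the global keyspace. A minimal sketch of the idea (an illustrative stand-in, not the project's RedisCustom class; it assumes a Redis server on localhost):

import pickle

from redis import Redis


class NamespacedRedis:
    """Tiny wrapper that prefixes every key with a namespace."""

    def __init__(self, namespace: str, db: int = 0):
        self.namespace = namespace
        self.redis = Redis(host='localhost', port=6379, db=db)

    def _key(self, name: str) -> str:
        return f'{self.namespace}:{name}'

    def setp(self, name, value):
        # Without _key() here, pickled values are written outside the
        # namespace and getp() can never find them again.
        self.redis.set(self._key(name), pickle.dumps(value))

    def getp(self, name: str):
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r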
diff --git a/llm_server/llm/oobabooga/ooba_backend.py b/llm_server/llm/oobabooga/ooba_backend.py
index 0e2b2d8..18fe6b1 100644
--- a/llm_server/llm/oobabooga/ooba_backend.py
+++ b/llm_server/llm/oobabooga/ooba_backend.py
@@ -1,79 +1,6 @@
-from flask import jsonify
-
-from llm_server.custom_redis import redis
 from ..llm_backend import LLMBackend
-from ...database.database import do_db_log
-from ...helpers import safe_list_get
-from ...routes.helpers.client import format_sillytavern_err
-from ...routes.helpers.http import validate_json
 
 
 class OobaboogaBackend(LLMBackend):
-    default_params = {}
-
-    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
-        raise NotImplementedError('need to implement default_params')
-
-        backend_err = False
-        response_valid_json, response_json_body = validate_json(response)
-        if response:
-            try:
-                # Be extra careful when getting attributes from the response object
-                response_status_code = response.status_code
-            except:
-                response_status_code = 0
-        else:
-            response_status_code = None
-
-        # ===============================================
-
-        # We encountered an error
-        if not success or not response or error_msg:
-            if not error_msg or error_msg == '':
-                error_msg = 'Unknown error.'
-            else:
-                error_msg = error_msg.strip('.') + '.'
-            backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
-            log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': error_msg,
-                'results': [{'text': backend_response}]
-            }), 400
-
-        # ===============================================
-
-        if response_valid_json:
-            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
-            if not backend_response:
-                # Ooba doesn't return any error messages so we will just tell the client an error occurred
-                backend_err = True
-                backend_response = format_sillytavern_err(
-                    f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
-                    error_type='error',
-                    backend_url=self.backend_url)
-                response_json_body['results'][0]['text'] = backend_response
-
-            if not backend_err:
-                redis.incr('proompts')
-
-            log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
-            return jsonify({
-                **response_json_body
-            }), 200
-        else:
-            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
-            log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
-            return jsonify({
-                'code': 500,
-                'msg': 'the backend did not return valid JSON',
-                'results': [{'text': backend_response}]
-            }), 400
-
-    def validate_params(self, params_dict: dict):
-        # No validation required
-        return True, None
-
-    def get_parameters(self, parameters):
-        del parameters['prompt']
-        return parameters
+    def __init__(self):
+        return
diff --git a/llm_server/llm/vllm/tokenize.py b/llm_server/llm/vllm/tokenize.py
index 8b18073..69a2b14 100644
--- a/llm_server/llm/vllm/tokenize.py
+++ b/llm_server/llm/vllm/tokenize.py
@@ -33,7 +33,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
             j = r.json()
             return j['length']
         except Exception as e:
-            print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
+            print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
             return len(tokenizer.encode(chunk)) + 10
 
     # Use a ThreadPoolExecutor to send all chunks to the server at once
diff --git a/llm_server/routes/ooba_request_handler.py b/llm_server/routes/ooba_request_handler.py
index 6966e32..804be74 100644
--- a/llm_server/routes/ooba_request_handler.py
+++ b/llm_server/routes/ooba_request_handler.py
@@ -17,7 +17,7 @@ class OobaRequestHandler(RequestHandler):
         assert not self.used
         if self.offline:
             print(messages.BACKEND_OFFLINE)
-            self.handle_error(messages.BACKEND_OFFLINE)
+            return self.handle_error(messages.BACKEND_OFFLINE)
 
         request_valid, invalid_response = self.validate_request()
         if not request_valid:
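The ooba_request_handler.py change is the kind of bug where an error response is built but never returned, so the handler keeps going and hits the offline backend anyway. A minimal illustration of the pattern (hypothetical names, not the project's RequestHandler API):

class Handler:
    def __init__(self, offline: bool):
        self.offline = offline

    def handle_error(self, msg: str):
        return {'code': 503, 'msg': msg}

    def handle_request(self):
        if self.offline:
            # Returning here is the whole fix; without `return`, execution
            # falls through to the normal request path below.
            return self.handle_error('backend offline')
        return {'code': 200, 'msg': 'ok'}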
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index 426fd98..87a7330 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -79,6 +79,7 @@ def openai_chat_completions(model_name=None):
         event = None
         if not handler.is_client_ratelimited():
+            start_time = time.time()
             # Add a dummy event to the queue and wait for it to reach a worker
             event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
 
         if not event:
@@ -102,11 +103,14 @@
         pubsub = redis.pubsub()
         pubsub.subscribe(event_id)
         for item in pubsub.listen():
+            if time.time() - start_time >= opts.backend_generate_request_timeout:
+                raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
             if item['type'] == 'message':
                 msg = item['data'].decode('utf-8')
                 if msg == 'begin':
                     break
                 elif msg == 'offline':
+                    # This shouldn't happen because the best model should be auto-selected.
                     return return_invalid_model_err(handler.request_json_body['model'])
             time.sleep(0.1)
 
@@ -135,6 +139,7 @@
                     json_obj = json.loads(json_str.decode())
                     new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                     generated_text = generated_text + new
+                    redis.publish(event_id, 'chunk')  # Keepalive
                 except IndexError:
                     # ????
                     continue
@@ -170,9 +175,14 @@
                     r_url,
                     handler.backend_url,
                 )
+            except GeneratorExit:
+                yield 'data: [DONE]\n\n'
+            except:
+                # AttributeError: 'bool' object has no attribute 'iter_content'
+                traceback.print_exc()
+                yield 'data: [DONE]\n\n'
             finally:
-                # After completing inference, we need to tell the worker we
-                # are finished.
+                # After completing inference, we need to tell the worker we are finished.
                 if event_id:  # may be None if ratelimited.
                     redis.publish(event_id, 'finished')
                 else:
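The streaming path now exchanges three signals on the per-request pub/sub channel: 'begin' when a worker picks the job up, 'chunk' as a keepalive while tokens stream, and 'finished' when the endpoint is done, with a wall-clock timeout on the waiting side. A rough sketch of that handshake, reduced to plain redis-py (the channel handling, timeout value, and function names here are illustrative, not the project's API):

import time

from redis import Redis

r = Redis(host='localhost', port=6379)
TIMEOUT = 300  # seconds; stand-in for opts.backend_generate_request_timeout


def wait_for_begin(event_id: str) -> bool:
    """Block until a worker publishes 'begin', or give up after TIMEOUT."""
    pubsub = r.pubsub()
    pubsub.subscribe(event_id)
    start = time.time()
    while time.time() - start < TIMEOUT:
        item = pubsub.get_message(timeout=1.0)
        if item and item['type'] == 'message':
            msg = item['data'].decode('utf-8')
            if msg == 'begin':
                return True
            if msg == 'offline':
                return False
    return False


def stream_tokens(event_id: str, chunks):
    """Publish a keepalive per chunk and a final 'finished' marker."""
    try:
        for chunk in chunks:
            r.publish(event_id, 'chunk')  # keepalive so the worker doesn't time out
            yield chunk
    finally:
        r.publish(event_id, 'finished')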
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index 24bc019..3e4279f 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -6,6 +6,7 @@ from uuid import uuid4
 
 from redis import Redis
 
+from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
 from llm_server.custom_redis import RedisCustom, redis
 from llm_server.database.database import get_token_ratelimit
@@ -23,9 +24,14 @@ def decrement_ip_count(client_ip: str, redis_key):
 
 class RedisPriorityQueue:
     def __init__(self, name, db: int = 12):
+        self.name = name
         self.redis = RedisCustom(name, db=db)
 
     def put(self, item, priority, selected_model):
+        assert item is not None
+        assert priority is not None
+        assert selected_model is not None
+
         event = DataEvent()
         # Check if the IP is already in the dictionary and if it has reached the limit
         ip_count = self.redis.hget('queued_ip_count', item[1])
@@ -36,7 +42,8 @@
                 print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
                 return None  # reject the request
 
-        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
+        timestamp = time.time()
+        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
         self.increment_ip_count(item[1], 'queued_ip_count')
         return event
 
@@ -52,11 +59,13 @@
 
     def print_all_items(self):
         items = self.redis.zrange('queue', 0, -1)
+        to_print = []
         for item in items:
-            print(item.decode('utf-8'))
+            to_print.append(item.decode('utf-8'))
+        print(f'ITEMS {self.name} -->', to_print)
 
     def increment_ip_count(self, client_ip: str, redis_key):
-        new_count = self.redis.hincrby(redis_key, client_ip, 1)
+        self.redis.hincrby(redis_key, client_ip, 1)
 
     def decrement_ip_count(self, client_ip: str, redis_key):
         new_count = self.redis.hincrby(redis_key, client_ip, -1)
@@ -75,6 +84,16 @@
     def flush(self):
         self.redis.flush()
 
+    def cleanup(self):
+        now = time.time()
+        items = self.redis.zrange('queue', 0, -1)
+        for item in items:
+            item_data = json.loads(item)
+            timestamp = item_data[-1]
+            if now - timestamp > opts.backend_generate_request_timeout * 3:  # TODO: config option
+                self.redis.zrem('queue', item)
+                print('removed item from queue:', item)
+
 
 class DataEvent:
     def __init__(self, event_id=None):
@@ -112,7 +131,7 @@ def decr_active_workers(selected_model: str, backend_url: str):
 
 
 class PriorityQueue:
-    def __init__(self, backends: list = None):
+    def __init__(self, backends: set = None):
         """
         Only have to load the backends once.
         :param backends:
@@ -120,10 +139,10 @@
         self.redis = Redis(host='localhost', port=6379, db=9)
         if backends:
             for item in backends:
-                self.redis.lpush('backends', item)
+                self.redis.sadd('backends', item)
 
     def get_backends(self):
-        return [x.decode('utf-8') for x in self.redis.lrange('backends', 0, -1)]
+        return {x.decode('utf-8') for x in self.redis.smembers('backends')}
 
     def get_queued_ip_count(self, client_ip: str):
         count = 0
@@ -136,22 +155,32 @@
         queue = RedisPriorityQueue(backend_url)
         return queue.put(item, priority, selected_model)
 
+    def activity(self):
+        lines = []
+        status_redis = RedisCustom('worker_status')
+        for worker in status_redis.keys():
+            lines.append((worker, status_redis.getp(worker)))
+        return sorted(lines)
+
     def len(self, model_name):
         count = 0
-        backends_with_models = []
+        backends_with_models = set()
         for k in self.get_backends():
             info = cluster_config.get_backend(k)
             if info.get('model') == model_name:
-                backends_with_models.append(k)
+                backends_with_models.add(k)
         for backend_url in backends_with_models:
             count += len(RedisPriorityQueue(backend_url))
         return count
 
     def __len__(self):
         count = 0
+        p = set()
         for backend_url in self.get_backends():
             queue = RedisPriorityQueue(backend_url)
+            p.add((backend_url, len(queue)))
             count += len(queue)
+        print(p)
         return count
 
     def flush(self):
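queue.py is the core of the commit: each queue entry now carries its enqueue timestamp inside the JSON member, which is what makes the new cleanup() pass possible. Stripped of the project's wrappers, the mechanism looks roughly like this (key names and the stale threshold are made up for illustration, and a local Redis is assumed):

import json
import time

from redis import Redis

r = Redis(host='localhost', port=6379, db=12)
STALE_AFTER = 3 * 300  # seconds; stand-in for backend_generate_request_timeout * 3


def put(item, priority: int) -> None:
    # Lower score pops first, so negate the priority; the timestamp rides
    # along inside the member rather than in the score.
    member = json.dumps((item, time.time()))
    r.zadd('queue', {member: -priority})


def get():
    # zpopmin returns [(member, score)] for the highest-priority entry.
    popped = r.zpopmin('queue', 1)
    if not popped:
        return None
    item, timestamp = json.loads(popped[0][0])
    return item, timestamp


def cleanup() -> None:
    # Drop entries whose enqueue time is too old to ever be served.
    now = time.time()
    for member in r.zrange('queue', 0, -1):
        _, timestamp = json.loads(member)
        if now - timestamp > STALE_AFTER:
            r.zrem('queue', member)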
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index d65d125..f41d0d3 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -1,20 +1,48 @@
+import queue
 import threading
 import time
 import traceback
+from uuid import uuid4
 
+from redis.client import PubSub
+
+from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import redis
+from llm_server.custom_redis import RedisCustom, redis
 from llm_server.llm.generator import generator
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
 
 
+class ListenerThread(threading.Thread):
+    def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
+        threading.Thread.__init__(self)
+        self.pubsub = pubsub
+        self.listener_queue = listener_queue
+        self.stop_event = stop_event
+
+    def run(self):
+        while not self.stop_event.is_set():
+            message = self.pubsub.get_message()
+            if message:
+                self.listener_queue.put(message)
+            time.sleep(0.1)
+
+
 def worker(backend_url):
-    queue = RedisPriorityQueue(backend_url)
+    status_redis = RedisCustom('worker_status')
+    worker_id = uuid4()
+    status_redis.setp(str(worker_id), None)
+    redis_queue = RedisPriorityQueue(backend_url)
     while True:
-        (request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
+        (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
         backend_info = cluster_config.get_backend(backend_url)
+
         pubsub = redis.pubsub()
         pubsub.subscribe(event_id)
+        stop_event = threading.Event()
+        q = queue.Queue()
+        listener = ListenerThread(pubsub, q, stop_event)
+        listener.start()
 
         if not backend_info['online']:
             redis.publish(event_id, 'offline')
@@ -26,6 +54,8 @@
         increment_ip_count(client_ip, 'processing_ips')
         incr_active_workers(selected_model, backend_url)
 
+        status_redis.setp(str(worker_id), (backend_url, client_ip))
+
         try:
             if not request_json_body:
                 # This was a dummy request from the streaming handlers.
@@ -34,13 +64,27 @@
                 # is finished. Since a lot of ratelimiting and stats are
                 # based off the number of active workers, we must keep
                 # the generation based off the workers.
+                start_time = time.time()
                 redis.publish(event_id, 'begin')
-                for item in pubsub.listen():
-                    if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
-                        # The streaming endpoint has said that it has finished
+                while True:
+                    status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))
+
+                    try:
+                        item = q.get(timeout=30)
+                    except queue.Empty:
+                        print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
+                        status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
+                        break
+
+                    if time.time() - start_time >= opts.backend_generate_request_timeout:
+                        status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
+                        print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
+                        break
+                    if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
+                        status_redis.setp(str(worker_id), ('streaming completed', client_ip))
                         break
-                    time.sleep(0.1)
             else:
+                status_redis.setp(str(worker_id), ('generating', client_ip))
                 # Normal inference (not streaming).
                 success, response, error_msg = generator(request_json_body, backend_url)
                 event = DataEvent(event_id)
@@ -48,8 +92,11 @@
         except:
             traceback.print_exc()
         finally:
+            stop_event.set()  # make sure to stop the listener thread
+            listener.join()
             decrement_ip_count(client_ip, 'processing_ips')
             decr_active_workers(selected_model, backend_url)
+            status_redis.setp(str(worker_id), None)
 
 
 def start_workers(cluster: dict):
diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py
index 37c1178..d342f4b 100644
--- a/llm_server/workers/mainer.py
+++ b/llm_server/workers/mainer.py
@@ -7,6 +7,7 @@ from llm_server.cluster.cluster_config import cluster_config, get_backends
 from llm_server.custom_redis import redis
 from llm_server.database.database import weighted_average_column_for_model
 from llm_server.llm.info import get_info
+from llm_server.routes.queue import RedisPriorityQueue, priority_queue
 
 
 def main_background_thread():
@@ -35,6 +36,11 @@
         except Exception as e:
             print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
 
+        backends = priority_queue.get_backends()
+        for backend_url in backends:
+            queue = RedisPriorityQueue(backend_url)
+            queue.cleanup()
+
         time.sleep(30)
 
 
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index ed6ff65..a3da690 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -24,5 +24,7 @@
         for k in processing:
             processing_count += redis.get(k, default=0, dtype=int)
         backends = [k for k, v in cluster_config.all().items() if v['online']]
-        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
-        time.sleep(10)
+        activity = priority_queue.activity()
+        print(activity)
+        logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
+        time.sleep(1)
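The reason inferencer.py moves pub/sub consumption into ListenerThread is that pubsub.listen() blocks indefinitely, so the worker could never enforce a per-chunk or overall deadline. Draining messages into a queue.Queue lets the consumer use q.get(timeout=...) instead. The same pattern reduced to a generic poll function, as a hedged sketch (the poll source and timeout values are placeholders, not the project's settings):

import queue
import threading
import time


class PollingListener(threading.Thread):
    """Poll a non-blocking source and push anything it returns onto a queue."""

    def __init__(self, poll, out_queue: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.poll = poll
        self.out_queue = out_queue
        self.stop_event = stop_event

    def run(self):
        while not self.stop_event.is_set():
            message = self.poll()
            if message is not None:
                self.out_queue.put(message)
            time.sleep(0.1)


def consume(poll, per_item_timeout=30, overall_timeout=300):
    """Consume items with both a per-item and an overall deadline."""
    q = queue.Queue()
    stop = threading.Event()
    listener = PollingListener(poll, q, stop)
    listener.start()
    start = time.time()
    try:
        while time.time() - start < overall_timeout:
            try:
                item = q.get(timeout=per_item_timeout)
            except queue.Empty:
                break  # nothing arrived in time; give up on this request
            if item == 'finished':
                break
    finally:
        stop.set()  # always stop the listener, mirroring the worker's finally block
        listener.join()

For example, consume(pubsub.get_message) would behave much like the worker loop above, since redis-py's get_message() returns None when nothing is pending.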