From e9f6fdf65e130c7f9e9fbff0346c56a5d136ba0f Mon Sep 17 00:00:00 2001 From: Cyberes Date: Thu, 5 Oct 2023 20:14:28 -0600 Subject: [PATCH] fix streaming? --- llm_server/config/config.py | 1 - llm_server/config/load.py | 5 +- llm_server/opts.py | 4 +- llm_server/routes/openai/chat_completions.py | 126 ++++++++++--------- llm_server/routes/openai/completions.py | 123 +++++++++--------- llm_server/routes/queue.py | 7 -- llm_server/routes/v1/generate_stream.py | 45 +++++-- llm_server/workers/inferencer.py | 50 +++----- 8 files changed, 193 insertions(+), 168 deletions(-) diff --git a/llm_server/config/config.py b/llm_server/config/config.py index 54eb3ec..2c08544 100644 --- a/llm_server/config/config.py +++ b/llm_server/config/config.py @@ -33,7 +33,6 @@ config_default_vars = { 'openai_moderation_enabled': True, 'netdata_root': None, 'show_backends': True, - 'cluster_workers': 30, 'background_homepage_cacher': True, 'openai_moderation_timeout': 5, 'prioritize_by_size': False diff --git a/llm_server/config/load.py b/llm_server/config/load.py index 9a55a70..2847265 100644 --- a/llm_server/config/load.py +++ b/llm_server/config/load.py @@ -45,12 +45,15 @@ def load_config(config_path): opts.openai_silent_trim = config['openai_silent_trim'] opts.openai_moderation_enabled = config['openai_moderation_enabled'] opts.show_backends = config['show_backends'] - opts.cluster_workers = config['cluster_workers'] opts.background_homepage_cacher = config['background_homepage_cacher'] opts.openai_moderation_timeout = config['openai_moderation_timeout'] opts.frontend_api_mode = config['frontend_api_mode'] opts.prioritize_by_size = config['prioritize_by_size'] + # Scale the number of workers. + for item in config['cluster']: + opts.cluster_workers += item['concurrent_gens'] + if opts.openai_expose_our_model and not opts.openai_api_key: print('If you set openai_epose_our_model to false, you must set your OpenAI key in openai_api_key.') sys.exit(1) diff --git a/llm_server/opts.py b/llm_server/opts.py index 5c32f05..69b25eb 100644 --- a/llm_server/opts.py +++ b/llm_server/opts.py @@ -34,7 +34,7 @@ openai_silent_trim = False openai_moderation_enabled = True cluster = {} show_backends = True -cluster_workers = 30 background_homepage_cacher = True openai_moderation_timeout = 5 -prioritize_by_size = False \ No newline at end of file +prioritize_by_size = False +cluster_workers = 0 \ No newline at end of file diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index e470a7b..6e1fdf5 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -8,7 +8,7 @@ from llm_server.custom_redis import redis from . import openai_bp from ..helpers.http import validate_json from ..openai_request_handler import OpenAIRequestHandler -from ..queue import decr_active_workers, decrement_ip_count, priority_queue +from ..queue import priority_queue from ... import opts from ...database.log_to_db import log_to_db from ...llm.generator import generator @@ -57,6 +57,7 @@ def openai_chat_completions(): else: handler.prompt = transform_messages_to_prompt(handler.request.json['messages']) + event_id = None response_status_code = 0 start_time = time.time() @@ -70,8 +71,10 @@ def openai_chat_completions(): 'stream': True, } - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) + event = None + if not handler.is_client_ratelimited(): + # Add a dummy event to the queue and wait for it to reach a worker + event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) if not event: log_to_db( handler.client_ip, @@ -87,8 +90,15 @@ def openai_chat_completions(): ) return handler.handle_ratelimited() - # Wait for a worker to get our request and discard it. - _, _, _ = event.wait() + # Once the worker receives our streaming request, it will tell us we are ready + # to begin inference. + event_id = event.event_id + pubsub = redis.pubsub() + pubsub.subscribe(event_id) + for item in pubsub.listen(): + if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin': + break + time.sleep(0.1) try: r_headers = dict(request.headers) @@ -97,61 +107,63 @@ def openai_chat_completions(): oai_string = generate_oai_string(30) def generate(): - try: - response = generator(msg_to_backend, handler.backend_url) - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue + response = generator(msg_to_backend, handler.backend_url) + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(handler.prompt + generated_text)[1] + generated_text = generated_text + new + except IndexError: + # ???? + continue - data = { - "id": f"chatcmpl-{oai_string}", - "object": "chat.completion.chunk", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - response_status_code, - r_url, - handler.backend_url, - ) - finally: - # The worker incremented it, we'll decrement it. - decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers(handler.selected_model, handler.backend_url) + data = { + "id": f"chatcmpl-{oai_string}", + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": new + }, + "finish_reason": None + } + ] + } + yield f'data: {json.dumps(data)}\n\n' + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + response_status_code, + r_url, + handler.backend_url, + ) return Response(generate(), mimetype='text/event-stream') except Exception: traceback.print_exc() return 'INTERNAL SERVER', 500 + finally: + # After completing inference, we need to tell the worker we + # are finished. + if event_id: # may be None if ratelimited. + redis.publish(event_id, 'finished') + else: + print('event_id was None!') diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index 1843226..9c42cf6 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -8,9 +8,8 @@ from llm_server.custom_redis import redis from . import openai_bp from ..helpers.http import validate_json from ..ooba_request_handler import OobaRequestHandler -from ..queue import decr_active_workers, decrement_ip_count, priority_queue +from ..queue import priority_queue from ... import opts -from ...database.database import do_db_log from ...database.log_to_db import log_to_db from ...llm import get_token_count from ...llm.generator import generator @@ -53,7 +52,6 @@ def openai_completions(): return handler.handle_ratelimited() output = response.json['results'][0]['text'] - # TODO: async/await prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url) response_tokens = get_token_count(output, handler.backend_url) running_model = redis.get('running_model', 'ERROR', dtype=str) @@ -86,6 +84,7 @@ def openai_completions(): if not opts.enable_streaming: return 'DISABLED', 401 + event_id = None response_status_code = 0 start_time = time.time() @@ -100,8 +99,10 @@ def openai_completions(): 'stream': True, } - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) + event = None + if not handler.is_client_ratelimited(): + # Add a dummy event to the queue and wait for it to reach a worker + event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) if not event: log_to_db( handler.client_ip, @@ -117,8 +118,14 @@ def openai_completions(): ) return handler.handle_ratelimited() - # Wait for a worker to get our request and discard it. - _, _, _ = event.wait() + # Wait for permission to begin. + event_id = event.event_id + pubsub = redis.pubsub() + pubsub.subscribe(event_id) + for item in pubsub.listen(): + if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin': + break + time.sleep(0.1) try: response = generator(msg_to_backend, handler.backend_url) @@ -128,61 +135,61 @@ def openai_completions(): oai_string = generate_oai_string(30) def generate(): - try: - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue + generated_text = '' + partial_response = b'' + for chunk in response.iter_content(chunk_size=1): + partial_response += chunk + if partial_response.endswith(b'\x00'): + json_strs = partial_response.split(b'\x00') + for json_str in json_strs: + if json_str: + try: + json_obj = json.loads(json_str.decode()) + new = json_obj['text'][0].split(handler.prompt + generated_text)[1] + generated_text = generated_text + new + except IndexError: + # ???? + continue - data = { - "id": f"cmpl-{oai_string}", - "object": "text_completion", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time + data = { + "id": f"cmpl-{oai_string}", + "object": "text_completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": new + }, + "finish_reason": None + } + ] + } + yield f'data: {json.dumps(data)}\n\n' + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - response_status_code, - r_url, - handler.backend_url, - ) - finally: - # The worker incremented it, we'll decrement it. - decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers(handler.selected_model, handler.backend_url) + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + response_status_code, + r_url, + handler.backend_url, + ) return Response(generate(), mimetype='text/event-stream') except Exception: traceback.print_exc() return 'INTERNAL SERVER', 500 + finally: + if event_id: + redis.publish(event_id, 'finished') + else: + print('event_id was None!') diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py index f6e7993..b075ead 100644 --- a/llm_server/routes/queue.py +++ b/llm_server/routes/queue.py @@ -22,8 +22,6 @@ def decrement_ip_count(client_ip: str, redis_key): class RedisPriorityQueue: def __init__(self, name: str = 'priority_queue', db: int = 12): self.redis = RedisCustom(name, db=db) - self.pubsub = self.redis.pubsub() - self.pubsub.subscribe('events') def put(self, item, priority, selected_model): event = DataEvent() @@ -36,8 +34,6 @@ class RedisPriorityQueue: print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.') return None # reject the request - print('--->', event.event_id) - self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority}) self.increment_ip_count(item[1], 'queued_ip_count') return event @@ -54,17 +50,14 @@ class RedisPriorityQueue: def print_all_items(self): items = self.redis.zrange('queue', 0, -1) - print(items) for item in items: print(item.decode('utf-8')) def increment_ip_count(self, client_ip: str, redis_key): new_count = self.redis.hincrby(redis_key, client_ip, 1) - print(client_ip, new_count) def decrement_ip_count(self, client_ip: str, redis_key): new_count = self.redis.hincrby(redis_key, client_ip, -1) - print(client_ip, new_count) if new_count <= 0: self.redis.hdel(redis_key, client_ip) diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 55fceb9..e3818c2 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -7,8 +7,9 @@ from flask import request from . import bp from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler -from ..queue import decr_active_workers, decrement_ip_count, priority_queue +from ..queue import priority_queue from ... import opts +from ...custom_redis import redis from ...database.log_to_db import log_to_db from ...llm.generator import generator from ...sock import sock @@ -94,6 +95,7 @@ def do_stream(ws, model_name): # TODO: implement other backends raise NotImplementedError + event_id = None generated_text = '' input_prompt = request_json_body['prompt'] response_status_code = 0 @@ -117,16 +119,33 @@ def do_stream(ws, model_name): 'stream': True, } - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) + event = None + if not handler.is_client_ratelimited(): + # Add a dummy event to the queue and wait for it to reach a worker + event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model) if not event: - r, _ = handler.handle_ratelimited() - err_msg = r.json['results'][0]['text'] - send_err_and_quit(err_msg) - return + log_to_db( + handler.client_ip, + handler.token, + handler.request_json_body.get('prompt'), + None, + None, + handler.parameters, + request.headers, + response_status_code, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() - # Wait for a worker to get our request and discard it. - _, _, _ = event.wait() + # Wait for permission to begin. + event_id = event.event_id + pubsub = redis.pubsub() + pubsub.subscribe(event_id) + for item in pubsub.listen(): + if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin': + break + time.sleep(0.1) try: response = generator(llm_request, handler.backend_url) @@ -195,9 +214,11 @@ def do_stream(ws, model_name): })) # used to log here finally: - # The worker incremented it, we'll decrement it. - decrement_ip_count(handler.client_ip, 'processing_ips') - decr_active_workers(handler.selected_model, handler.backend_url) + if event_id: + redis.publish(event_id, 'finished') + else: + print('event_id was None!') + try: ws.send(json.dumps({ 'event': 'stream_end', diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index 178bfd6..a545ae6 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -19,27 +19,30 @@ def worker(): if not selected_model: selected_model = backend_info['model'] - # This wait time will be "invisible", meaning the worker may as - # well be still waiting to get an item from the queue. - need_to_wait(backend_url) - increment_ip_count(client_ip, 'processing_ips') incr_active_workers(selected_model, backend_url) - print('<---', event_id) - - if not request_json_body: - # This was a dummy request from the websocket handlers. - # We're going to let the websocket handler decrement - # processing_ips and active_gen_workers. - event = DataEvent(event_id) - event.set((True, None, None)) - continue - try: - success, response, error_msg = generator(request_json_body, backend_url) - event = DataEvent(event_id) - event.set((success, response, error_msg)) + if not request_json_body: + # This was a dummy request from the streaming handlers. + # The worker will let the handler do the streaming instead + # of the worker. The worker will block until the handler + # is finished. Since a lot of ratelimiting and stats are + # based off the number of active workers, we must keep + # the generation based off the workers. + pubsub = redis.pubsub() + pubsub.subscribe(event_id) + redis.publish(event_id, 'begin') + for item in pubsub.listen(): + if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished': + # Once the handler is complete, move on. + break + time.sleep(0.1) + else: + # Normal inference (not streaming). + success, response, error_msg = generator(request_json_body, backend_url) + event = DataEvent(event_id) + event.set((success, response, error_msg)) finally: decrement_ip_count(client_ip, 'processing_ips') decr_active_workers(selected_model, backend_url) @@ -53,16 +56,3 @@ def start_workers(num_workers: int): t.start() i += 1 print(f'Started {i} inference workers.') - - -def need_to_wait(backend_url: str): - # We need to check the number of active workers since the streaming endpoint may be doing something. - active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int) - concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1) - s = time.time() - print(active_workers) - while active_workers >= concurrent_gens: - time.sleep(0.01) - e = time.time() - if e - s > 0.1: - print(f'Worker was delayed {e - s} seconds.')