From 2c7773cc4f147e1f9a07522809291def37c641ee Mon Sep 17 00:00:00 2001 From: Cyberes Date: Mon, 16 Oct 2023 16:22:52 -0600 Subject: [PATCH] get streaming working again --- llm_server/custom_redis.py | 9 + llm_server/database/log_to_db.py | 3 + llm_server/netdata.py | 52 ----- llm_server/routes/openai/chat_completions.py | 187 ++++++++-------- llm_server/routes/openai/completions.py | 221 ++++++++++--------- llm_server/routes/openai_request_handler.py | 4 - llm_server/routes/request_handler.py | 3 + llm_server/routes/v1/generate_stream.py | 155 ++++--------- llm_server/sock.py | 2 +- llm_server/workers/inferencer.py | 18 +- llm_server/workers/printer.py | 2 +- other/ooba-test-streaming.py | 4 +- server.py | 9 +- 13 files changed, 296 insertions(+), 373 deletions(-) delete mode 100644 llm_server/netdata.py diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py index aacaec0..60e4dbd 100644 --- a/llm_server/custom_redis.py +++ b/llm_server/custom_redis.py @@ -223,5 +223,14 @@ class RedisCustom(Redis): self.flush() return True + def lrange(self, name: str, start: int, end: int): + return self.redis.lrange(self._key(name), start, end) + + def delete(self, *names: KeyT): + return self.redis.delete(*[self._key(i) for i in names]) + + def lpop(self, name: str, count: Optional[int] = None): + return self.redis.lpop(self._key(name), count) + redis = RedisCustom('local_llm') diff --git a/llm_server/database/log_to_db.py b/llm_server/database/log_to_db.py index be6946f..75bcaab 100644 --- a/llm_server/database/log_to_db.py +++ b/llm_server/database/log_to_db.py @@ -5,6 +5,9 @@ from redis import Redis def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False): + assert isinstance(prompt, str) + assert isinstance(backend_url, str) + r = Redis(host='localhost', port=6379, db=3) data = { 'function': 'log_prompt', diff --git a/llm_server/netdata.py b/llm_server/netdata.py deleted file mode 100644 index f37c109..0000000 --- a/llm_server/netdata.py +++ /dev/null @@ -1,52 +0,0 @@ -import json -from datetime import datetime, timedelta - -import requests - -from llm_server import opts - - -def get_power_states(): - gpu_num = 0 - output = {} - while True: - url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state" - try: - response = requests.get(url, timeout=10) - if response.status_code != 200: - break - data = json.loads(response.text) - power_state_data = data['data'][0] - power_state = None - for i in range(1, len(power_state_data)): - if power_state_data[i] == 1: - power_state = data['labels'][i] - break - output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p')) - except Exception as e: - print('Failed to fetch Netdata metrics:', e) - return output - gpu_num += 1 - return output - - -def get_gpu_wh(gpu_id: int): - chart_name = f"nvidia_smi.gpu{gpu_id}_power" - now = datetime.now() - one_hour_ago = now - timedelta(hours=1) - num_seconds = int((now - one_hour_ago).total_seconds()) - params = { - "chart": chart_name, - "after": int(one_hour_ago.timestamp()), - "before": int(now.timestamp()), - "points": num_seconds, - "group": "second", - "format": "json", - "options": "absolute|jsonwrap" - } - response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10) - data = json.loads(response.text) - total_power_usage_watts = sum(point[1] for point in data['result']['data']) - # total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1) - total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3) - return total_power_usage_kwh diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py index a840070..44d5172 100644 --- a/llm_server/routes/openai/chat_completions.py +++ b/llm_server/routes/openai/chat_completions.py @@ -43,24 +43,23 @@ def openai_chat_completions(model_name=None): if not opts.enable_streaming: return 'Streaming disabled', 403 - handler.parameters, _ = handler.get_parameters() - handler.request_json_body = { - 'messages': handler.request_json_body['messages'], - 'model': handler.request_json_body['model'], - **handler.parameters - } - invalid_oai_err_msg = validate_oai(handler.request_json_body) if invalid_oai_err_msg: return invalid_oai_err_msg handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode']) + handler.parameters, e = handler.get_parameters() + handler.request_json_body = { + 'messages': handler.request_json_body['messages'], + 'model': handler.request_json_body['model'], + **handler.parameters + } + if opts.openai_silent_trim: handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)) else: handler.prompt = transform_messages_to_prompt(handler.request.json['messages']) - if not handler.prompt: # Prevent issues on the backend. return 'Invalid prompt', 400 @@ -73,90 +72,94 @@ def openai_chat_completions(model_name=None): request_valid, invalid_response = handler.validate_request() if not request_valid: return invalid_response - else: - event = None - if not handler.is_client_ratelimited(): - event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) - if not event: - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - None, - None, - handler.parameters, - request.headers, - 429, - request.url, - handler.backend_url, - ) - return handler.handle_ratelimited() - try: - r_headers = dict(request.headers) - r_url = request.url - model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') - oai_string = generate_oai_string(30) + event = None + if not handler.is_client_ratelimited(): + event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) + if not event: + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + None, + None, + handler.parameters, + request.headers, + 429, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() - def generate(): - stream_name = event.wait() - stream_redis = Redis(db=8) - generated_text = '' - try: - while True: - stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000) - if not stream_data: - print("No message received in 30 seconds, closing stream.") - yield 'data: [DONE]\n\n' - else: - for r_timestamp, item in stream_data[0][1]: - timestamp = int(r_timestamp.decode('utf-8').split('-')[0]) - data = pickle.loads(item[b'data']) - if data['error']: - yield 'data: [DONE]\n\n' - return - elif data['new']: - response = { - "id": f"chatcmpl-{oai_string}", - "object": "chat.completion.chunk", - "created": timestamp, - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": data['new'] - }, - "finish_reason": None - } - ] - } - generated_text = generated_text + data['new'] - yield f'data: {json.dumps(response)}\n\n' - elif data['completed']: - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - 200, - r_url, - handler.backend_url, - ) - return - except (Exception, GeneratorExit): - traceback.print_exc() - yield 'data: [DONE]\n\n' - finally: - stream_redis.delete(stream_name) + try: + r_headers = dict(request.headers) + r_url = request.url + model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') + oai_string = generate_oai_string(30) - return Response(generate(), mimetype='text/event-stream') - except Exception: - traceback.print_exc() - return 'INTERNAL SERVER', 500 + def generate(): + stream_name = event.wait() + stream_redis = Redis(db=8) + generated_text = '' + try: + last_id = '0-0' + while True: + stream_data = stream_redis.xread({stream_name: last_id}, block=30000) + if not stream_data: + print("No message received in 30 seconds, closing stream.") + yield 'data: [DONE]\n\n' + else: + for stream_index, item in stream_data[0][1]: + last_id = stream_index + timestamp = int(stream_index.decode('utf-8').split('-')[0]) + data = pickle.loads(item[b'data']) + if data['error']: + yield 'data: [DONE]\n\n' + return + elif data['new']: + response = { + "id": f"chatcmpl-{oai_string}", + "object": "chat.completion.chunk", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": data['new'] + }, + "finish_reason": None + } + ] + } + generated_text = generated_text + data['new'] + yield f'data: {json.dumps(response)}\n\n' + elif data['completed']: + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + 200, + r_url, + handler.backend_url, + ) + return + except (Exception, GeneratorExit): + traceback.print_exc() + yield 'data: [DONE]\n\n' + finally: + if event: + redis.lpush(f'notifications:{event.event_id}', 'canceled') + stream_redis.delete(stream_name) + + return Response(generate(), mimetype='text/event-stream') + except Exception: + traceback.print_exc() + return 'INTERNAL SERVER', 500 diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py index dc7f9e6..374cdc2 100644 --- a/llm_server/routes/openai/completions.py +++ b/llm_server/routes/openai/completions.py @@ -1,8 +1,10 @@ +import pickle import time import traceback import simplejson as json from flask import Response, jsonify, request +from redis import Redis from llm_server.custom_redis import redis from . import openai_bp, openai_model_bp @@ -12,7 +14,6 @@ from ..queue import priority_queue from ... import opts from ...database.log_to_db import log_to_db from ...llm import get_token_count -from ...llm.generator import generator from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai from ...llm.openai.transform import generate_oai_string, trim_string_to_fit @@ -42,12 +43,14 @@ def openai_completions(model_name=None): handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode']) if opts.openai_silent_trim: - handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) + handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url) else: # The handle_request() call below will load the prompt so we don't have # to do anything else here. pass + handler.request_json_body['prompt'] = handler.prompt + if not request_json_body.get('stream'): invalid_oai_err_msg = validate_oai(request_json_body) if invalid_oai_err_msg: @@ -89,120 +92,120 @@ def openai_completions(model_name=None): if not opts.enable_streaming: return 'Streaming disabled', 403 - event_id = None + request_valid, invalid_response = handler.validate_request() + if not request_valid: + return invalid_response + + handler.parameters, _ = handler.get_parameters() + handler.request_json_body = { + 'prompt': handler.request_json_body['prompt'], + 'model': handler.request_json_body['model'], + **handler.parameters + } + + invalid_oai_err_msg = validate_oai(handler.request_json_body) + if invalid_oai_err_msg: + return invalid_oai_err_msg + + if opts.openai_silent_trim: + handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']] + if not handler.prompt: + # Prevent issues on the backend. + return 'Invalid prompt', 400 + start_time = time.time() request_valid, invalid_response = handler.validate_request() if not request_valid: return invalid_response - else: - handler.prompt = handler.request_json_body['prompt'] - msg_to_backend = { - **handler.parameters, - 'prompt': handler.prompt, - 'stream': True, - } - event = None - if not handler.is_client_ratelimited(): - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model) - if not event: - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - None, - None, - handler.parameters, - request.headers, - 429, - request.url, - handler.backend_url, - ) - return handler.handle_ratelimited() + event = None + if not handler.is_client_ratelimited(): + event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) + if not event: + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + None, + None, + handler.parameters, + request.headers, + 429, + request.url, + handler.backend_url, + ) + return handler.handle_ratelimited() - # Wait for permission to begin. - event_id = event.event_id - pubsub = redis.pubsub() - pubsub.subscribe(event_id) - for item in pubsub.listen(): - if item['type'] == 'message': - msg = item['data'].decode('utf-8') - if msg == 'begin': - break - elif msg == 'offline': - return return_invalid_model_err(handler.request_json_body['model']) - time.sleep(0.1) + try: + r_headers = dict(request.headers) + r_url = request.url + model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') + oai_string = generate_oai_string(30) - # Double check the model is still online - if not handler.check_online(): - return return_invalid_model_err(handler.request_json_body['model']) - - try: - response = generator(msg_to_backend, handler.backend_url) - r_headers = dict(request.headers) - r_url = request.url - model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') - oai_string = generate_oai_string(30) - - def generate(): - try: - generated_text = '' - partial_response = b'' - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(handler.prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue - - data = { - "id": f"cmpl-{oai_string}", - "object": "text_completion", - "created": int(time.time()), - "model": model, - "choices": [ - { - "index": 0, - "delta": { - "content": new - }, - "finish_reason": None - } - ] - } - yield f'data: {json.dumps(data)}\n\n' - yield 'data: [DONE]\n\n' - end_time = time.time() - elapsed_time = end_time - start_time - - log_to_db( - handler.client_ip, - handler.token, - handler.prompt, - generated_text, - elapsed_time, - handler.parameters, - r_headers, - 200, - r_url, - handler.backend_url, - ) - finally: - if event_id: - redis.publish(event_id, 'finished') + def generate(): + stream_name = event.wait() + stream_redis = Redis(db=8) + generated_text = '' + try: + last_id = '0-0' + while True: + stream_data = stream_redis.xread({stream_name: last_id}, block=30000) + if not stream_data: + print("No message received in 30 seconds, closing stream.") + yield 'data: [DONE]\n\n' else: - print('event_id was None!') + for stream_index, item in stream_data[0][1]: + last_id = stream_index + timestamp = int(stream_index.decode('utf-8').split('-')[0]) + data = pickle.loads(item[b'data']) + if data['error']: + yield 'data: [DONE]\n\n' + return + elif data['new']: + response = { + "id": f"cmpl-{oai_string}", + "object": "text_completion", + "created": timestamp, + "model": model, + "choices": [ + { + "index": 0, + "delta": { + "content": data['new'] + }, + "finish_reason": None + } + ] + } + generated_text = generated_text + data['new'] + yield f'data: {json.dumps(response)}\n\n' + elif data['completed']: + yield 'data: [DONE]\n\n' + end_time = time.time() + elapsed_time = end_time - start_time + log_to_db( + handler.client_ip, + handler.token, + handler.prompt, + generated_text, + elapsed_time, + handler.parameters, + r_headers, + 200, + r_url, + handler.backend_url, + ) + return + except (Exception, GeneratorExit): + traceback.print_exc() + yield 'data: [DONE]\n\n' + finally: + if event: + redis.lpush(f'notifications:{event.event_id}', 'canceled') + stream_redis.delete(stream_name) - return Response(generate(), mimetype='text/event-stream') - except Exception: - traceback.print_exc() - return 'INTERNAL SERVER', 500 + return Response(generate(), mimetype='text/event-stream') + except Exception: + traceback.print_exc() + return 'INTERNAL SERVER', 500 diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py index 549cc93..9cbb11c 100644 --- a/llm_server/routes/openai_request_handler.py +++ b/llm_server/routes/openai_request_handler.py @@ -150,10 +150,6 @@ class OpenAIRequestHandler(RequestHandler): "total_tokens": prompt_tokens + response_tokens } }), 200) - - stats = redis.get('proxy_stats', dtype=dict) - if stats: - response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec'] return response def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]: diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py index 4011030..0fe94ec 100644 --- a/llm_server/routes/request_handler.py +++ b/llm_server/routes/request_handler.py @@ -37,6 +37,9 @@ class RequestHandler: self.parameters = None self.used = False + # This is null by default since most handlers need to transform the prompt in a specific way. + self.prompt = None + self.selected_model = selected_model self.backend_url = get_a_cluster_backend(selected_model) self.cluster_backend_info = cluster_config.get_backend(self.backend_url) diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py index 332fe4c..7c02cc9 100644 --- a/llm_server/routes/v1/generate_stream.py +++ b/llm_server/routes/v1/generate_stream.py @@ -1,17 +1,18 @@ import json +import pickle import time import traceback from flask import request +from redis import Redis from . import bp from ..helpers.http import require_api_key, validate_json from ..ooba_request_handler import OobaRequestHandler from ..queue import priority_queue -from ... import messages, opts +from ... import opts from ...custom_redis import redis from ...database.log_to_db import log_to_db -from ...llm.generator import generator from ...sock import sock @@ -35,6 +36,7 @@ def stream_with_model(ws, model_name=None): def do_stream(ws, model_name): + event_id = None try: def send_err_and_quit(quitting_err_msg): ws.send(json.dumps({ @@ -46,6 +48,7 @@ def do_stream(ws, model_name): 'event': 'stream_end', 'message_num': 1 })) + ws.close() log_to_db(ip=handler.client_ip, token=handler.token, prompt=input_prompt, @@ -55,7 +58,7 @@ def do_stream(ws, model_name): headers=r_headers, backend_response_code=response_status_code, request_url=r_url, - backend_url=handler.cluster_backend_info, + backend_url=handler.backend_url, response_tokens=None, is_error=True ) @@ -74,6 +77,7 @@ def do_stream(ws, model_name): if not request_valid_json or not request_json_body.get('prompt'): return 'Invalid JSON', 400 else: + # We have to do auth ourselves since the details are sent in the message. auth_failure = require_api_key(request_json_body) if auth_failure: return auth_failure @@ -89,14 +93,10 @@ def do_stream(ws, model_name): })) return - assert not handler.offline - if handler.cluster_backend_info['mode'] != 'vllm': # TODO: implement other backends raise NotImplementedError - event_id = None - generated_text = '' input_prompt = request_json_body['prompt'] response_status_code = 0 start_time = time.time() @@ -113,119 +113,55 @@ def do_stream(ws, model_name): send_err_and_quit(err_msg) return - llm_request = { - **handler.parameters, - 'prompt': input_prompt, - 'stream': True, + handler.parameters, _ = handler.get_parameters() + handler.prompt = input_prompt + handler.request_json_body = { + 'prompt': handler.prompt, + **handler.parameters } event = None if not handler.is_client_ratelimited(): - # Add a dummy event to the queue and wait for it to reach a worker - event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model) + event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True) if not event: - log_to_db( - handler.client_ip, - handler.token, - handler.request_json_body.get('prompt'), - None, - None, - handler.parameters, - request.headers, - response_status_code, - request.url, - handler.backend_url, - ) - return handler.handle_ratelimited() - - # Wait for permission to begin. + r = handler.handle_ratelimited() + send_err_and_quit(r[0].data) + return event_id = event.event_id - pubsub = redis.pubsub() - pubsub.subscribe(event_id) - for item in pubsub.listen(): - if item['type'] == 'message': - msg = item['data'].decode('utf-8') - if msg == 'begin': - break - elif msg == 'offline': - return messages.BACKEND_OFFLINE, 404 # TODO: format this error - time.sleep(0.1) - # Double check the model is still online - if not handler.check_online(): - return messages.BACKEND_OFFLINE, 404 # TODO: format this error + stream_name = event.wait() + stream_redis = Redis(db=8) + generated_text = '' try: - response = generator(llm_request, handler.backend_url) - if not response: - error_msg = 'Failed to reach backend while streaming.' - print('Streaming failed:', error_msg) - msg = handler.handle_error(error_msg)[0].json['results'][0]['text'] - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': msg - })) - else: - # Be extra careful when getting attributes from the response object - try: - response_status_code = response.status_code - except: - response_status_code = 0 - - partial_response = b'' - - for chunk in response.iter_content(chunk_size=1): - partial_response += chunk - if partial_response.endswith(b'\x00'): - json_strs = partial_response.split(b'\x00') - for json_str in json_strs: - if json_str: - try: - json_obj = json.loads(json_str.decode()) - new = json_obj['text'][0].split(input_prompt + generated_text)[1] - generated_text = generated_text + new - except IndexError: - # ???? - continue - try: - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': new - })) - except: - # The has client closed the stream. - if response: - # Cancel the backend? - response.close() - # used to log here - return - - message_num += 1 - partial_response = b'' # Reset the partial response - - # If there is no more data, break the loop - if not chunk: - break - if response: - response.close() - # used to log here + last_id = '0-0' # The ID of the last entry we read. + while True: + stream_data = stream_redis.xread({stream_name: last_id}, block=30000) + if not stream_data: + print("No message received in 30 seconds, closing stream.") + return + else: + for stream_index, item in stream_data[0][1]: + last_id = stream_index + data = pickle.loads(item[b'data']) + if data['error']: + print(data['error']) + send_err_and_quit('Encountered exception while streaming.') + return + elif data['new']: + ws.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': data['new'] + })) + message_num += 1 + generated_text = generated_text + data['new'] + elif data['completed']: + return except: + send_err_and_quit('Encountered exception while streaming.') traceback.print_exc() - generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text'] - ws.send(json.dumps({ - 'event': 'text_stream', - 'message_num': message_num, - 'text': generated_text - })) - # used to log here finally: - if event_id: - redis.publish(event_id, 'finished') - else: - print('event_id was None!') - try: ws.send(json.dumps({ 'event': 'stream_end', @@ -234,6 +170,7 @@ def do_stream(ws, model_name): except: # The client closed the stream. pass + stream_redis.delete(stream_name) end_time = time.time() elapsed_time = end_time - start_time log_to_db(ip=handler.client_ip, @@ -248,6 +185,8 @@ def do_stream(ws, model_name): backend_url=handler.backend_url ) finally: + if event_id: + redis.lpush(f'notifications:{event_id}', 'canceled') try: # Must close the connection or greenlets will complain. ws.close() diff --git a/llm_server/sock.py b/llm_server/sock.py index 8ac2fc1..2f1a17d 100644 --- a/llm_server/sock.py +++ b/llm_server/sock.py @@ -3,6 +3,6 @@ from flask_sock import Sock sock = Sock() -def init_socketio(app): +def init_wssocket(app): global sock sock.init_app(app) diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py index a8a73e0..b190b4d 100644 --- a/llm_server/workers/inferencer.py +++ b/llm_server/workers/inferencer.py @@ -7,7 +7,7 @@ from uuid import uuid4 from redis import Redis from llm_server.cluster.cluster_config import cluster_config -from llm_server.custom_redis import RedisCustom +from llm_server.custom_redis import RedisCustom, redis from llm_server.llm.generator import generator from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count @@ -20,15 +20,25 @@ def get_stream_name(name: str): return f'{STREAM_NAME_PREFIX}:{name}' -def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str): +def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str): prompt = msg_to_backend['prompt'] stream_name = get_stream_name(stream_name) + redis.delete(f'notifications:{event_id}') stream_redis.delete(get_stream_name(stream_name)) # be extra sure try: response = generator(msg_to_backend, backend_url) generated_text = '' partial_response = b'' for chunk in response.iter_content(chunk_size=1): + # If there is no more data, break the loop + if not chunk: + break + message = redis.lpop(f'notifications:{event_id}') + if message and message.decode('utf-8') == 'canceled': + print('Client canceled generation') + response.close() + return + partial_response += chunk if partial_response.endswith(b'\x00'): json_strs = partial_response.split(b'\x00') @@ -74,14 +84,16 @@ def worker(backend_url): try: if do_stream: + # Return the name of the stream that the slave should connect to. event = DataEvent(event_id) event.set(get_stream_name(worker_id)) + msg_to_backend = { **parameters, 'prompt': request_json_body['prompt'], 'stream': True, } - inference_do_stream(worker_id, msg_to_backend, backend_url) + inference_do_stream(worker_id, msg_to_backend, backend_url, event_id) else: # Normal inference (not streaming). success, response, error_msg = generator(request_json_body, backend_url) diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py index c28e383..4025df3 100644 --- a/llm_server/workers/printer.py +++ b/llm_server/workers/printer.py @@ -29,4 +29,4 @@ def console_printer(): # TODO: Active Workers and Processing should read the same. If not, that's an issue logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}') - time.sleep(10) + time.sleep(2) diff --git a/other/ooba-test-streaming.py b/other/ooba-test-streaming.py index 883c2f5..7f5185d 100644 --- a/other/ooba-test-streaming.py +++ b/other/ooba-test-streaming.py @@ -11,6 +11,7 @@ except ImportError: HOST = 'localhost:5000' URI = f'ws://{HOST}/api/v1/stream' + # For reverse-proxied streaming, the remote will likely host with ssl - wss:// # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' @@ -82,5 +83,6 @@ async def print_response_stream(prompt): if __name__ == '__main__': - prompt = "In order to make homemade bread, follow these steps:\n1)" + # prompt = "In order to make homemade bread, follow these steps:\n1)" + prompt = "Write a 300 word description of how an apple tree grows.\n\n" asyncio.run(print_response_stream(prompt)) diff --git a/server.py b/server.py index 490eebe..e33d55a 100644 --- a/server.py +++ b/server.py @@ -28,7 +28,7 @@ from llm_server.routes.openai import openai_bp, openai_model_bp from llm_server.routes.server_error import handle_server_error from llm_server.routes.v1 import bp from llm_server.routes.v1.generate_stats import generate_stats -from llm_server.sock import init_socketio +from llm_server.sock import init_wssocket # TODO: queue item timeout # TODO: return an `error: True`, error code, and error message rather than just a formatted message @@ -68,10 +68,15 @@ except ModuleNotFoundError as e: sys.exit(1) app = Flask(__name__) + +# Fixes ConcurrentObjectUseError +# https://github.com/miguelgrinberg/simple-websocket/issues/24 +app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25} + app.register_blueprint(bp, url_prefix='/api/') app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/') app.register_blueprint(openai_model_bp, url_prefix='/api/openai/') -init_socketio(app) +init_wssocket(app) flask_cache.init_app(app) flask_cache.clear()