From 806e522d166baff7ade0973d55cbbe3e46f7f5a2 Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Mon, 16 Oct 2023 18:35:10 -0600
Subject: [PATCH] don't pickle streaming

---
 llm_server/opts.py                           | 3 ++-
 llm_server/routes/openai/chat_completions.py | 8 ++++----
 llm_server/routes/openai/completions.py      | 8 ++++----
 llm_server/routes/v1/generate_stream.py      | 8 ++++----
 llm_server/workers/inferencer.py             | 8 ++++----
 requirements.txt                             | 3 ++-
 6 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/llm_server/opts.py b/llm_server/opts.py
index 69b25eb..ada54a8 100644
--- a/llm_server/opts.py
+++ b/llm_server/opts.py
@@ -37,4 +37,5 @@ show_backends = True
 background_homepage_cacher = True
 openai_moderation_timeout = 5
 prioritize_by_size = False
-cluster_workers = 0
\ No newline at end of file
+cluster_workers = 0
+redis_stream_timeout = 25000
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index e863636..5e5921a 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -1,5 +1,5 @@
 import json
-import pickle
+import ujson
 import time
 import traceback
 
@@ -104,15 +104,15 @@ def openai_chat_completions(model_name=None):
         try:
             last_id = '0-0'
             while True:
-                stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
+                stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                 if not stream_data:
-                    print("No message received in 30 seconds, closing stream.")
+                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                     yield 'data: [DONE]\n\n'
                 else:
                     for stream_index, item in stream_data[0][1]:
                         last_id = stream_index
                         timestamp = int(stream_index.decode('utf-8').split('-')[0])
-                        data = pickle.loads(item[b'data'])
+                        data = ujson.loads(item[b'data'])
                         if data['error']:
                             yield 'data: [DONE]\n\n'
                             return
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index 4df336f..b8efb07 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -1,8 +1,8 @@
-import pickle
 import time
 import traceback
 
 import simplejson as json
+import ujson
 from flask import Response, jsonify, request
 from redis import Redis
 
@@ -150,15 +150,15 @@ def openai_completions(model_name=None):
         try:
             last_id = '0-0'
             while True:
-                stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
+                stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
                 if not stream_data:
-                    print("No message received in 30 seconds, closing stream.")
+                    print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                     yield 'data: [DONE]\n\n'
                 else:
                     for stream_index, item in stream_data[0][1]:
                         last_id = stream_index
                         timestamp = int(stream_index.decode('utf-8').split('-')[0])
-                        data = pickle.loads(item[b'data'])
+                        data = ujson.loads(item[b'data'])
                         if data['error']:
                             yield 'data: [DONE]\n\n'
                             return
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 7c02cc9..29eb281 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -1,8 +1,8 @@
 import json
-import pickle
 import time
 import traceback
 
+import ujson
 from flask import request
 from redis import Redis
 
@@ -136,14 +136,14 @@ def do_stream(ws, model_name):
     try:
         last_id = '0-0'  # The ID of the last entry we read.
         while True:
-            stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
+            stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
             if not stream_data:
-                print("No message received in 30 seconds, closing stream.")
+                print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
                 return
             else:
                 for stream_index, item in stream_data[0][1]:
                     last_id = stream_index
-                    data = pickle.loads(item[b'data'])
+                    data = ujson.loads(item[b'data'])
                     if data['error']:
                         print(data['error'])
                         send_err_and_quit('Encountered exception while streaming.')
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index b190b4d..b6e94dc 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -1,9 +1,9 @@
 import json
-import pickle
 import threading
 import traceback
 from uuid import uuid4
 
+import ujson
 from redis import Redis
 
 from llm_server.cluster.cluster_config import cluster_config
@@ -51,13 +51,13 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
                 except IndexError:
                     # ????
                     continue
-                stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})})
+                stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
     except Exception as e:
-        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
+        stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
         traceback.print_exc()
     finally:
         # Publish final message to Redis stream
-        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
+        stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
 
 
 def worker(backend_url):
diff --git a/requirements.txt b/requirements.txt
index 28e818f..802d6f2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,5 @@ openai~=0.28.0
 urllib3~=2.0.4
 flask-sock==0.6.0
 gunicorn==21.2.0
-redis==5.0.1
\ No newline at end of file
+redis==5.0.1
+ujson==5.8.0
\ No newline at end of file
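
Note (reviewer sketch, not part of the patch): the change swaps pickle for
ujson as the serializer for Redis stream messages, so payloads are plain JSON
text instead of pickled Python objects. Consumers no longer unpickle bytes
(unpickling untrusted data can execute arbitrary code), and non-Python clients
can read the stream. Below is a minimal standalone round-trip of the new wire
format: the payload shape ({'new': ..., 'completed': ..., 'error': ...}) and
the reader loop mirror the patch, while the Redis connection details, the
stream name, and the hard-coded 25000 ms timeout (standing in for
opts.redis_stream_timeout) are assumptions for this demo only.

    import ujson
    from redis import Redis

    stream_redis = Redis(host='localhost', port=6379, db=0)  # connection details assumed for this demo
    stream_name = 'demo-stream'  # hypothetical stream name, not from the patch

    # Producer side (cf. inferencer.py): events are JSON text, not pickles.
    stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': 'Hello', 'completed': False, 'error': None})})
    stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})

    # Consumer side (cf. the route handlers): block up to 25000 ms per read.
    last_id = '0-0'  # the ID of the last entry we read
    done = False
    while not done:
        stream_data = stream_redis.xread({stream_name: last_id}, block=25000)
        if not stream_data:
            print('No message received before the timeout, closing stream.')
            break
        for stream_index, item in stream_data[0][1]:
            last_id = stream_index
            data = ujson.loads(item[b'data'])  # ujson.loads accepts the raw bytes Redis returns
            print(data)
            if data['completed']:
                done = True

One caveat of the new format: anything put in the 'new' field must now be
JSON-serializable, whereas pickle accepted arbitrary Python objects.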