don't pickle streaming

This commit is contained in:
Cyberes 2023-10-16 18:35:10 -06:00
parent 81baf9616f
commit 806e522d16
6 changed files with 20 additions and 18 deletions

View File

@ -38,3 +38,4 @@ background_homepage_cacher = True
openai_moderation_timeout = 5 openai_moderation_timeout = 5
prioritize_by_size = False prioritize_by_size = False
cluster_workers = 0 cluster_workers = 0
redis_stream_timeout = 25000

View File

@ -1,5 +1,5 @@
import json import json
import pickle import ujson
import time import time
import traceback import traceback
@ -104,15 +104,15 @@ def openai_chat_completions(model_name=None):
try: try:
last_id = '0-0' last_id = '0-0'
while True: while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=30000) stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data: if not stream_data:
print("No message received in 30 seconds, closing stream.") print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
else: else:
for stream_index, item in stream_data[0][1]: for stream_index, item in stream_data[0][1]:
last_id = stream_index last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0]) timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data']) data = ujson.loads(item[b'data'])
if data['error']: if data['error']:
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
return return

View File

@ -1,8 +1,8 @@
import pickle
import time import time
import traceback import traceback
import simplejson as json import simplejson as json
import ujson
from flask import Response, jsonify, request from flask import Response, jsonify, request
from redis import Redis from redis import Redis
@ -150,15 +150,15 @@ def openai_completions(model_name=None):
try: try:
last_id = '0-0' last_id = '0-0'
while True: while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=30000) stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data: if not stream_data:
print("No message received in 30 seconds, closing stream.") print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
else: else:
for stream_index, item in stream_data[0][1]: for stream_index, item in stream_data[0][1]:
last_id = stream_index last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0]) timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data']) data = ujson.loads(item[b'data'])
if data['error']: if data['error']:
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
return return

View File

@ -1,8 +1,8 @@
import json import json
import pickle
import time import time
import traceback import traceback
import ujson
from flask import request from flask import request
from redis import Redis from redis import Redis
@ -136,14 +136,14 @@ def do_stream(ws, model_name):
try: try:
last_id = '0-0' # The ID of the last entry we read. last_id = '0-0' # The ID of the last entry we read.
while True: while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=30000) stream_data = stream_redis.xread({stream_name: last_id}, block=opts.redis_stream_timeout)
if not stream_data: if not stream_data:
print("No message received in 30 seconds, closing stream.") print(f"No message received in {opts.redis_stream_timeout / 1000} seconds, closing stream.")
return return
else: else:
for stream_index, item in stream_data[0][1]: for stream_index, item in stream_data[0][1]:
last_id = stream_index last_id = stream_index
data = pickle.loads(item[b'data']) data = ujson.loads(item[b'data'])
if data['error']: if data['error']:
print(data['error']) print(data['error'])
send_err_and_quit('Encountered exception while streaming.') send_err_and_quit('Encountered exception while streaming.')

View File

@ -1,9 +1,9 @@
import json import json
import pickle
import threading import threading
import traceback import traceback
from uuid import uuid4 from uuid import uuid4
import ujson
from redis import Redis from redis import Redis
from llm_server.cluster.cluster_config import cluster_config from llm_server.cluster.cluster_config import cluster_config
@ -51,13 +51,13 @@ def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str
except IndexError: except IndexError:
# ???? # ????
continue continue
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})}) stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': new, 'completed': False, 'error': None})})
except Exception as e: except Exception as e:
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})}) stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
traceback.print_exc() traceback.print_exc()
finally: finally:
# Publish final message to Redis stream # Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})}) stream_redis.xadd(stream_name, {'data': ujson.dumps({'new': None, 'completed': True, 'error': None})})
def worker(backend_url): def worker(backend_url):

View File

@ -14,3 +14,4 @@ urllib3~=2.0.4
flask-sock==0.6.0 flask-sock==0.6.0
gunicorn==21.2.0 gunicorn==21.2.0
redis==5.0.1 redis==5.0.1
ujson==5.8.0