begin streaming rewrite
commit 151b3e4769
parent 24aab3cd93

@@ -1,8 +1,10 @@
import json
import pickle
import time
import traceback

from flask import Response, jsonify, request
from redis import Redis

from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
@@ -11,7 +13,6 @@ from ..openai_request_handler import OpenAIRequestHandler
from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
@@ -64,24 +65,18 @@ def openai_chat_completions(model_name=None):
# Prevent issues on the backend.
return 'Invalid prompt', 400

event_id = None
# Need to set the prompt in the JSON body since that's what the inference worker expects.
handler.request_json_body['prompt'] = handler.prompt

start_time = time.time()

request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
else:
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}

event = None
if not handler.is_client_ratelimited():
start_time = time.time()
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
@@ -97,27 +92,6 @@ def openai_chat_completions(model_name=None):
)
return handler.handle_ratelimited()

# Once the worker receives our streaming request, it will tell us we are ready
# to begin inference.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if time.time() - start_time >= opts.backend_generate_request_timeout:
raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
# This shouldn't happen because the best model should be auto-selected.
return return_invalid_model_err(handler.request_json_body['model'])
time.sleep(0.1)

# Double check the model is still online
if not handler.check_online():
return return_invalid_model_err(handler.request_json_body['model'])

try:
r_headers = dict(request.headers)
r_url = request.url
@@ -125,68 +99,62 @@ def openai_chat_completions(model_name=None):
oai_string = generate_oai_string(30)

def generate():
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
try:
response = generator(msg_to_backend, handler.backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
redis.publish(event_id, 'chunk') # Keepalive
except IndexError:
# ????
continue

data = {
while True:
stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
if not stream_data:
print("No message received in 30 seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for r_timestamp, item in stream_data[0][1]:
timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data'])
if data['error']:
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
"content": data['new']
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
except GeneratorExit:
yield 'data: [DONE]\n\n'
except:
# AttributeError: 'bool' object has no attribute 'iter_content'
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
200,
r_url,
handler.backend_url,
)
return
except (Exception, GeneratorExit):
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
# After completing inference, we need to tell the worker we are finished.
if event_id: # may be None if ratelimited.
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
stream_redis.delete(stream_name)

return Response(generate(), mimetype='text/event-stream')
except Exception:
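The handoff above leans on the DataEvent class diffed below: the worker publishes the name of the Redis stream it will write to, and the route's generate() blocks in event.wait() until that name arrives. A minimal, self-contained sketch of that pickle-over-pub/sub pattern, using illustrative names rather than the project's exact class:

import pickle
from redis import Redis

r = Redis()

class MiniDataEvent:
    """Pass one pickled value between two processes via a Redis pub/sub channel."""

    def __init__(self, event_id: str):
        self.event_id = event_id
        self.pubsub = r.pubsub()
        # Subscribe up front so a value published after construction is not missed.
        self.pubsub.subscribe(event_id)

    def set(self, data):
        r.publish(self.event_id, pickle.dumps(data))

    def wait(self):
        for item in self.pubsub.listen():
            if item['type'] == 'message':
                return pickle.loads(item['data'])

# Worker side:  MiniDataEvent(event_id).set('stream:<worker_id>')
# Route side:   stream_name = MiniDataEvent(event_id).wait()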
@@ -27,7 +27,7 @@ class RedisPriorityQueue:
self.name = name
self.redis = RedisCustom(name, db=db)

def put(self, item, priority, selected_model):
def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
assert item is not None
assert priority is not None
assert selected_model is not None
@@ -43,7 +43,7 @@ class RedisPriorityQueue:
return None # reject the request

timestamp = time.time()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
return event
@@ -106,6 +106,7 @@ class DataEvent:
self.redis.publish(self.event_id, pickle.dumps(data))

def wait(self):
# TODO: implement timeout
for item in self.pubsub.listen():
if item['type'] == 'message':
return pickle.loads(item['data'])
@@ -151,9 +152,9 @@ class PriorityQueue:
count += queue.get_queued_ip_count(client_ip)
return count

def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model)
return queue.put(item, priority, selected_model, do_stream)

def activity(self):
lines = []
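For context, a queued entry is a JSON blob stored in a sorted set and scored with the negated priority, and this commit appends the new do_stream flag to that tuple. A rough sketch of the round trip with redis-py follows; the dequeue half is an illustration only, since RedisPriorityQueue.get() is not part of this diff:

import json
import time

from redis import Redis

r = Redis()

def enqueue(item, event_id, selected_model, priority, do_stream=False):
    # Mirrors RedisPriorityQueue.put(): the whole entry is JSON-encoded and
    # scored with -priority so that higher priorities sort first in the zset.
    entry = (item, event_id, selected_model, time.time(), do_stream)
    r.zadd('queue', {json.dumps(entry): -priority})

def dequeue():
    # Hypothetical pop, only to show how the 5-tuple round-trips; the worker
    # below unpacks exactly this shape from redis_queue.get().
    popped = r.zpopmin('queue', 1)
    if not popped:
        return None
    raw, _score = popped[0]
    (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = json.loads(raw)
    return request_json_body, client_ip, token, parameters, event_id, selected_model, timestamp, do_stream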
@@ -0,0 +1,32 @@
import time

from redis import Redis

from llm_server.workers.inferencer import STREAM_NAME_PREFIX


# NOT NEEDED

def cleaner():
r = Redis(db=8)
stream_info = {}

while True:
all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
processed_streams = []
for stream in all_streams:
stream = stream.decode()
current_size = r.xlen(stream)

# If the stream is new or its size has changed, update the size and time in the dictionary
if stream not in stream_info or current_size != stream_info[stream]['size']:
stream_info[stream] = {'size': current_size, 'time': time.time()}
processed_streams.append(stream)
else:
# If the size hasn't changed for 5 minutes, delete the stream
if time.time() - stream_info[stream]['time'] >= 300:
r.delete(stream)
print(f"Stream '{stream}' deleted due to inactivity.")
del stream_info[stream]

time.sleep(60)
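This helper is explicitly marked NOT NEEDED and nothing in the commit starts it; it only deletes streams whose XLEN has not changed for five minutes. If it were wired in, the natural place would be a background daemon thread, something like the sketch below (the wiring is hypothetical, not part of this commit):

import threading

# Hypothetical wiring -- the commit itself never launches the cleaner.
threading.Thread(target=cleaner, daemon=True, name='stream_cleaner').start()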
@@ -1,90 +1,88 @@
import queue
import json
import pickle
import threading
import time
import traceback
from uuid import uuid4

from redis.client import PubSub
from redis import Redis

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.custom_redis import RedisCustom
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count

stream_redis = Redis(db=8)

class ListenerThread(threading.Thread):
def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
threading.Thread.__init__(self)
self.pubsub = pubsub
self.listener_queue = listener_queue
self.stop_event = stop_event
STREAM_NAME_PREFIX = 'stream'

def run(self):
while not self.stop_event.is_set():
message = self.pubsub.get_message()
if message:
self.listener_queue.put(message)
time.sleep(0.1)

def get_stream_name(name: str):
return f'{STREAM_NAME_PREFIX}:{name}'


def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})})
except Exception as e:
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
traceback.print_exc()
finally:
# Publish final message to Redis stream
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
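Each delta that inference_do_stream() produces is pickled as a {'new', 'completed', 'error'} dict and appended with XADD; the route handler re-reads the whole stream from '0-0' on every pass. A minimal stand-alone consumer for that format could look like the sketch below (read_stream and its de-duplication are illustrative, not part of this commit):

import pickle
import time

from redis import Redis

stream_redis = Redis(db=8)

def read_stream(stream_name: str, timeout: int = 30) -> str:
    # Collects the pickled {'new', 'completed', 'error'} entries written by
    # inference_do_stream(). Like the route handler, it reads from the
    # beginning ('0-0') on every pass, so already-seen entry IDs are skipped;
    # a real consumer would track the last ID it read instead.
    deadline = time.time() + timeout
    seen = set()
    text = ''
    while time.time() < deadline:
        for _stream, entries in stream_redis.xread({stream_name: '0-0'}, block=5000) or []:
            for entry_id, fields in entries:
                if entry_id in seen:
                    continue
                seen.add(entry_id)
                data = pickle.loads(fields[b'data'])
                if data['error'] or data['completed']:
                    return text
                text += data['new']
    return text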
def worker(backend_url):
status_redis = RedisCustom('worker_status')
worker_id = uuid4()
worker_id = str(uuid4())
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
backend_info = cluster_config.get_backend(backend_url)

pubsub = redis.pubsub()
pubsub.subscribe(event_id)
stop_event = threading.Event()
q = queue.Queue()
listener = ListenerThread(pubsub, q, stop_event)
listener.start()

if not backend_info['online']:
redis.publish(event_id, 'offline')
# TODO: communicate to caller
# redis.publish(event_id, 'offline')
return

if not selected_model:
selected_model = backend_info['model']

stream_redis.delete(get_stream_name(worker_id)) # clean up any old streams
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)

status_redis.setp(str(worker_id), (backend_url, client_ip))
status_redis.setp(str(worker_id), ('generating', client_ip))

try:
if not request_json_body:
# This was a dummy request from the streaming handlers.
# The worker will let the handler do the streaming instead
# of the worker. The worker will block until the handler
# is finished. Since a lot of ratelimiting and stats are
# based off the number of active workers, we must keep
# the generation based off the workers.
start_time = time.time()
redis.publish(event_id, 'begin')
while True:
status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))

try:
item = q.get(timeout=30)
except queue.Empty:
print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
break

if time.time() - start_time >= opts.backend_generate_request_timeout:
status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
break
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
status_redis.setp(str(worker_id), ('streaming completed', client_ip))
break
if do_stream:
event = DataEvent(event_id)
event.set(get_stream_name(worker_id))
msg_to_backend = {
**parameters,
'prompt': request_json_body['prompt'],
'stream': True,
}
inference_do_stream(worker_id, msg_to_backend, backend_url)
else:
status_redis.setp(str(worker_id), ('generating', client_ip))
# Normal inference (not streaming).
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)

@@ -92,8 +90,6 @@ def worker(backend_url):
except:
traceback.print_exc()
finally:
stop_event.set() # make sure to stop the listener thread
listener.join()
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None)
@@ -2,7 +2,6 @@ import time
from threading import Thread

from llm_server import opts
from llm_server.cluster.stores import redis_running_models
from llm_server.cluster.worker import cluster_worker
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.workers.inferencer import start_workers