begin streaming rewrite

commit 151b3e4769 (parent 24aab3cd93)
@@ -1,8 +1,10 @@
 import json
+import pickle
 import time
 import traceback
 
 from flask import Response, jsonify, request
+from redis import Redis
 
 from llm_server.custom_redis import redis
 from . import openai_bp, openai_model_bp
@@ -11,7 +13,6 @@ from ..openai_request_handler import OpenAIRequestHandler
 from ..queue import priority_queue
 from ... import opts
 from ...database.log_to_db import log_to_db
-from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
 
@@ -64,24 +65,18 @@ def openai_chat_completions(model_name=None):
                 # Prevent issues on the backend.
                 return 'Invalid prompt', 400
 
-            event_id = None
+            # Need to set the prompt in the JSON body since that's what the inference worker expects.
+            handler.request_json_body['prompt'] = handler.prompt
 
             start_time = time.time()
 
             request_valid, invalid_response = handler.validate_request()
             if not request_valid:
                 return invalid_response
             else:
-                msg_to_backend = {
-                    **handler.parameters,
-                    'prompt': handler.prompt,
-                    'stream': True,
-                }
-
                 event = None
                 if not handler.is_client_ratelimited():
-                    start_time = time.time()
-                    # Add a dummy event to the queue and wait for it to reach a worker
-                    event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
+                    event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
                 if not event:
                     log_to_db(
                         handler.client_ip,
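Note: condensed from the hunks in this diff, the handler-side handshake now works as follows: the full request is queued with do_stream=True, and the handler waits on the returned event, which the worker answers with the name of a Redis stream to read from. A minimal sketch using the same names as above:

```python
# Condensed sketch of the handler-side handshake introduced by this commit.
event = priority_queue.put(
    handler.backend_url,
    (handler.request_json_body, handler.client_ip, handler.token, handler.parameters),
    handler.token_priority,
    handler.selected_model,
    do_stream=True,
)
if not event:
    # The queue rejected the request (e.g. the client is ratelimited).
    return handler.handle_ratelimited()

stream_name = event.wait()  # blocks until the worker publishes the stream name
```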
@@ -97,27 +92,6 @@ def openai_chat_completions(model_name=None):
                     )
                     return handler.handle_ratelimited()
 
-            # Once the worker receives our streaming request, it will tell us we are ready
-            # to begin inference.
-            event_id = event.event_id
-            pubsub = redis.pubsub()
-            pubsub.subscribe(event_id)
-            for item in pubsub.listen():
-                if time.time() - start_time >= opts.backend_generate_request_timeout:
-                    raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
-                if item['type'] == 'message':
-                    msg = item['data'].decode('utf-8')
-                    if msg == 'begin':
-                        break
-                    elif msg == 'offline':
-                        # This shouldn't happen because the best model should be auto-selected.
-                        return return_invalid_model_err(handler.request_json_body['model'])
-                time.sleep(0.1)
-
-            # Double check the model is still online
-            if not handler.check_online():
-                return return_invalid_model_err(handler.request_json_body['model'])
-
             try:
                 r_headers = dict(request.headers)
                 r_url = request.url
@@ -125,68 +99,62 @@ def openai_chat_completions(model_name=None):
             oai_string = generate_oai_string(30)
 
             def generate():
+                stream_name = event.wait()
+                stream_redis = Redis(db=8)
+                generated_text = ''
                 try:
-                    response = generator(msg_to_backend, handler.backend_url)
-                    generated_text = ''
-                    partial_response = b''
-                    for chunk in response.iter_content(chunk_size=1):
-                        partial_response += chunk
-                        if partial_response.endswith(b'\x00'):
-                            json_strs = partial_response.split(b'\x00')
-                            for json_str in json_strs:
-                                if json_str:
-                                    try:
-                                        json_obj = json.loads(json_str.decode())
-                                        new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                        generated_text = generated_text + new
-                                        redis.publish(event_id, 'chunk')  # Keepalive
-                                    except IndexError:
-                                        # ????
-                                        continue
-
-                                    data = {
+                    while True:
+                        stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
+                        if not stream_data:
+                            print("No message received in 30 seconds, closing stream.")
+                            yield 'data: [DONE]\n\n'
+                        else:
+                            for r_timestamp, item in stream_data[0][1]:
+                                timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
+                                data = pickle.loads(item[b'data'])
+                                if data['error']:
+                                    yield 'data: [DONE]\n\n'
+                                    return
+                                elif data['new']:
+                                    response = {
                                         "id": f"chatcmpl-{oai_string}",
                                         "object": "chat.completion.chunk",
-                                        "created": int(time.time()),
+                                        "created": timestamp,
                                         "model": model,
                                         "choices": [
                                             {
                                                 "index": 0,
                                                 "delta": {
-                                                    "content": new
+                                                    "content": data['new']
                                                 },
                                                 "finish_reason": None
                                             }
                                         ]
                                     }
-                                    yield f'data: {json.dumps(data)}\n\n'
-                    yield 'data: [DONE]\n\n'
-                    end_time = time.time()
-                    elapsed_time = end_time - start_time
-                    log_to_db(
-                        handler.client_ip,
-                        handler.token,
-                        handler.prompt,
-                        generated_text,
-                        elapsed_time,
-                        handler.parameters,
-                        r_headers,
-                        200,
-                        r_url,
-                        handler.backend_url,
-                    )
-                except GeneratorExit:
-                    yield 'data: [DONE]\n\n'
-                except:
-                    # AttributeError: 'bool' object has no attribute 'iter_content'
+                                    generated_text = generated_text + data['new']
+                                    yield f'data: {json.dumps(response)}\n\n'
+                                elif data['completed']:
+                                    yield 'data: [DONE]\n\n'
+                                    end_time = time.time()
+                                    elapsed_time = end_time - start_time
+                                    log_to_db(
+                                        handler.client_ip,
+                                        handler.token,
+                                        handler.prompt,
+                                        generated_text,
+                                        elapsed_time,
+                                        handler.parameters,
+                                        r_headers,
+                                        200,
+                                        r_url,
+                                        handler.backend_url,
+                                    )
+                                    return
+                except (Exception, GeneratorExit):
                     traceback.print_exc()
                     yield 'data: [DONE]\n\n'
                 finally:
-                    # After completing inference, we need to tell the worker we are finished.
-                    if event_id:  # may be None if ratelimited.
-                        redis.publish(event_id, 'finished')
-                    else:
-                        print('event_id was None!')
+                    stream_redis.delete(stream_name)
 
             return Response(generate(), mimetype='text/event-stream')
         except Exception:
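Note: the rewritten generate() pulls chunks with XREAD and a 30-second block; each entry's data field is a pickled dict with new, completed, and error keys. A hedged, standalone sketch of reading such a stream (hypothetical stream name; this version advances the read cursor after each entry rather than re-reading from '0-0' on every iteration as the handler above does):

```python
import pickle

from redis import Redis


def read_stream(stream_name: str) -> str:
    """Drain one inference stream and return the concatenated text (sketch only)."""
    stream_redis = Redis(db=8)
    last_id = '0-0'  # start of the stream; advanced after every entry
    text = ''
    while True:
        reply = stream_redis.xread({stream_name: last_id}, block=30000)  # wait up to 30 s
        if not reply:
            return text  # nothing arrived within the block window
        for entry_id, fields in reply[0][1]:
            last_id = entry_id
            data = pickle.loads(fields[b'data'])
            if data['error'] or data['completed']:
                return text
            text += data['new']
```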
@@ -27,7 +27,7 @@ class RedisPriorityQueue:
         self.name = name
         self.redis = RedisCustom(name, db=db)
 
-    def put(self, item, priority, selected_model):
+    def put(self, item, priority: int, selected_model: str, do_stream: bool = False):
         assert item is not None
         assert priority is not None
         assert selected_model is not None
@@ -43,7 +43,7 @@ class RedisPriorityQueue:
             return None  # reject the request
 
         timestamp = time.time()
-        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
+        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp, do_stream)): -priority})
         self.increment_ip_count(item[1], 'queued_ip_count')
         return event
 
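Note: with this change a queue entry is a JSON-encoded (item, event_id, selected_model, timestamp, do_stream) tuple scored by -priority, so higher-priority requests sort first. A hedged sketch of the round trip with a plain redis-py client (the project actually goes through RedisCustom and its get() may not use zpopmin; all concrete values here are illustrative only):

```python
import json
import time

from redis import Redis

r = Redis()  # assumption: plain client standing in for the project's RedisCustom wrapper

# What put() stores after this commit: note the fifth do_stream field.
item = ({'prompt': 'hi'}, '127.0.0.1', None, {'max_new_tokens': 16})  # (request_json_body, client_ip, token, parameters)
r.zadd('queue', {json.dumps((item, 'event-id-123', 'example-model', time.time(), True)): -1})

# What a consumer has to unpack: five fields instead of four.
raw, _score = r.zpopmin('queue')[0]
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = json.loads(raw)
```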
@@ -106,6 +106,7 @@ class DataEvent:
         self.redis.publish(self.event_id, pickle.dumps(data))
 
     def wait(self):
+        # TODO: implement timeout
         for item in self.pubsub.listen():
             if item['type'] == 'message':
                 return pickle.loads(item['data'])
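Note: wait() still blocks in pubsub.listen() with no deadline; the new TODO flags that. One possible shape for a timeout, sketched with redis-py's get_message() (not part of this commit):

```python
import pickle
import time


def wait_with_timeout(pubsub, timeout: float = 30.0):
    """Like DataEvent.wait(), but gives up after `timeout` seconds (sketch only)."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        item = pubsub.get_message(timeout=1.0)  # returns None if nothing arrives in time
        if item and item['type'] == 'message':
            return pickle.loads(item['data'])
    return None  # caller decides how to treat a timed-out event
```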
@@ -151,9 +152,9 @@ class PriorityQueue:
             count += queue.get_queued_ip_count(client_ip)
         return count
 
-    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str):
+    def put(self, backend_url, item: Tuple[dict, str, str, dict], priority: int, selected_model: str, do_stream: bool = False):
         queue = RedisPriorityQueue(backend_url)
-        return queue.put(item, priority, selected_model)
+        return queue.put(item, priority, selected_model, do_stream)
 
     def activity(self):
         lines = []
@@ -0,0 +1,32 @@
+import time
+
+from redis import Redis
+
+from llm_server.workers.inferencer import STREAM_NAME_PREFIX
+
+
+# NOT NEEDED
+
+def cleaner():
+    r = Redis(db=8)
+    stream_info = {}
+
+    while True:
+        all_streams = r.keys(f'{STREAM_NAME_PREFIX}:*')
+        processed_streams = []
+        for stream in all_streams:
+            stream = stream.decode()
+            current_size = r.xlen(stream)
+
+            # If the stream is new or its size has changed, update the size and time in the dictionary
+            if stream not in stream_info or current_size != stream_info[stream]['size']:
+                stream_info[stream] = {'size': current_size, 'time': time.time()}
+                processed_streams.append(stream)
+            else:
+                # If the size hasn't changed for 5 minutes, delete the stream
+                if time.time() - stream_info[stream]['time'] >= 300:
+                    r.delete(stream)
+                    print(f"Stream '{stream}' deleted due to inactivity.")
+                    del stream_info[stream]
+
+        time.sleep(60)
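Note: this cleaner is already marked # NOT NEEDED, since generate() deletes its stream in its finally block and the worker clears stale streams before reuse. If it were kept, scan_iter() would avoid the blocking KEYS call; a hedged variant of the same 5-minute idle sweep:

```python
import time

from redis import Redis


def sweep_idle_streams(prefix: str = 'stream', idle_seconds: int = 300):
    r = Redis(db=8)
    last_seen = {}  # stream name -> (length, time that length was first observed)
    while True:
        for key in r.scan_iter(match=f'{prefix}:*'):  # SCAN instead of KEYS
            name = key.decode()
            size = r.xlen(name)
            prev = last_seen.get(name)
            if prev is None or prev[0] != size:
                last_seen[name] = (size, time.time())
            elif time.time() - prev[1] >= idle_seconds:
                r.delete(name)
                del last_seen[name]
        time.sleep(60)
```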
@@ -1,90 +1,88 @@
-import queue
+import json
+import pickle
 import threading
-import time
 import traceback
 from uuid import uuid4
 
-from redis.client import PubSub
+from redis import Redis
 
-from llm_server import opts
 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import RedisCustom, redis
+from llm_server.custom_redis import RedisCustom
 from llm_server.llm.generator import generator
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
 
+stream_redis = Redis(db=8)
 
-class ListenerThread(threading.Thread):
-    def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
-        threading.Thread.__init__(self)
-        self.pubsub = pubsub
-        self.listener_queue = listener_queue
-        self.stop_event = stop_event
+STREAM_NAME_PREFIX = 'stream'
 
-    def run(self):
-        while not self.stop_event.is_set():
-            message = self.pubsub.get_message()
-            if message:
-                self.listener_queue.put(message)
-            time.sleep(0.1)
+
+def get_stream_name(name: str):
+    return f'{STREAM_NAME_PREFIX}:{name}'
+
+
+def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
+    prompt = msg_to_backend['prompt']
+    stream_name = get_stream_name(stream_name)
+    stream_redis.delete(get_stream_name(stream_name))  # be extra sure
+    try:
+        response = generator(msg_to_backend, backend_url)
+        generated_text = ''
+        partial_response = b''
+        for chunk in response.iter_content(chunk_size=1):
+            partial_response += chunk
+            if partial_response.endswith(b'\x00'):
+                json_strs = partial_response.split(b'\x00')
+                for json_str in json_strs:
+                    if json_str:
+                        try:
+                            json_obj = json.loads(json_str.decode())
+                            new = json_obj['text'][0].split(prompt + generated_text)[1]
+                            generated_text = generated_text + new
+                        except IndexError:
+                            # ????
+                            continue
+                        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': new, 'completed': False, 'error': None})})
+    except Exception as e:
+        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': f'{e.__class__.__name__}: {e}'})})
+        traceback.print_exc()
+    finally:
+        # Publish final message to Redis stream
+        stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
 
 
 def worker(backend_url):
     status_redis = RedisCustom('worker_status')
-    worker_id = uuid4()
+    worker_id = str(uuid4())
     status_redis.setp(str(worker_id), None)
     redis_queue = RedisPriorityQueue(backend_url)
     while True:
-        (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
+        (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp, do_stream = redis_queue.get()
         backend_info = cluster_config.get_backend(backend_url)
 
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
-        stop_event = threading.Event()
-        q = queue.Queue()
-        listener = ListenerThread(pubsub, q, stop_event)
-        listener.start()
-
         if not backend_info['online']:
-            redis.publish(event_id, 'offline')
+            # TODO: communicate to caller
+            # redis.publish(event_id, 'offline')
             return
 
         if not selected_model:
             selected_model = backend_info['model']
 
+        stream_redis.delete(get_stream_name(worker_id))  # clean up any old streams
         increment_ip_count(client_ip, 'processing_ips')
         incr_active_workers(selected_model, backend_url)
-        status_redis.setp(str(worker_id), (backend_url, client_ip))
+        status_redis.setp(str(worker_id), ('generating', client_ip))
 
         try:
-            if not request_json_body:
-                # This was a dummy request from the streaming handlers.
-                # The worker will let the handler do the streaming instead
-                # of the worker. The worker will block until the handler
-                # is finished. Since a lot of ratelimiting and stats are
-                # based off the number of active workers, we must keep
-                # the generation based off the workers.
-                start_time = time.time()
-                redis.publish(event_id, 'begin')
-                while True:
-                    status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))
-
-                    try:
-                        item = q.get(timeout=30)
-                    except queue.Empty:
-                        print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
-                        status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
-                        break
-
-                    if time.time() - start_time >= opts.backend_generate_request_timeout:
-                        status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
-                        print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
-                        break
-                    if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
-                        status_redis.setp(str(worker_id), ('streaming completed', client_ip))
-                        break
+            if do_stream:
+                event = DataEvent(event_id)
+                event.set(get_stream_name(worker_id))
+                msg_to_backend = {
+                    **parameters,
+                    'prompt': request_json_body['prompt'],
+                    'stream': True,
+                }
+                inference_do_stream(worker_id, msg_to_backend, backend_url)
            else:
-                status_redis.setp(str(worker_id), ('generating', client_ip))
                 # Normal inference (not streaming).
                 success, response, error_msg = generator(request_json_body, backend_url)
                 event = DataEvent(event_id)
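Note: the producer contract set up by inference_do_stream(): every chunk is XADDed as a pickled dict with new, completed, and error keys, and the finally block always appends a terminal entry, so a reader can stop cleanly even after an error. A minimal sketch (hypothetical stream name; real names are stream:&lt;worker uuid&gt; via get_stream_name()):

```python
import pickle

from redis import Redis

stream_redis = Redis(db=8)
stream_name = 'stream:example-worker-id'  # hypothetical; see get_stream_name() above

for token in ('Hello', ',', ' world'):
    stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': token, 'completed': False, 'error': None})})

# Terminal entry: readers treat completed=True (or a non-None error) as end of stream.
stream_redis.xadd(stream_name, {'data': pickle.dumps({'new': None, 'completed': True, 'error': None})})
```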
@@ -92,8 +90,6 @@ def worker(backend_url):
         except:
             traceback.print_exc()
         finally:
-            stop_event.set()  # make sure to stop the listener thread
-            listener.join()
             decrement_ip_count(client_ip, 'processing_ips')
             decr_active_workers(selected_model, backend_url)
             status_redis.setp(str(worker_id), None)
@@ -2,7 +2,6 @@ import time
 from threading import Thread
 
 from llm_server import opts
-from llm_server.cluster.stores import redis_running_models
 from llm_server.cluster.worker import cluster_worker
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.workers.inferencer import start_workers