get streaming working again

Cyberes 2023-10-16 16:22:52 -06:00
parent 151b3e4769
commit 2c7773cc4f
13 changed files with 296 additions and 373 deletions

View File

@@ -223,5 +223,14 @@ class RedisCustom(Redis):
self.flush()
return True
def lrange(self, name: str, start: int, end: int):
return self.redis.lrange(self._key(name), start, end)
def delete(self, *names: KeyT):
return self.redis.delete(*[self._key(i) for i in names])
def lpop(self, name: str, count: Optional[int] = None):
return self.redis.lpop(self._key(name), count)
redis = RedisCustom('local_llm')
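RedisCustom routes every operation through _key() so each key is namespaced under the prefix given at construction ('local_llm' here), letting several services share one Redis instance without collisions. A minimal sketch of the pattern; the PrefixedRedis class below is illustrative, not part of this commit:

    import redis

    class PrefixedRedis:
        """Illustrative sketch: namespaces every key under a fixed prefix."""

        def __init__(self, prefix: str):
            self.prefix = prefix
            self.redis = redis.Redis()

        def _key(self, key: str) -> str:
            # 'local_llm' + 'queue' -> 'local_llm:queue'
            return f'{self.prefix}:{key}'

        def lrange(self, name: str, start: int, end: int):
            return self.redis.lrange(self._key(name), start, end)

        def lpop(self, name: str, count=None):
            return self.redis.lpop(self._key(name), count)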

View File

@@ -5,6 +5,9 @@ from redis import Redis
def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
assert isinstance(prompt, str)
assert isinstance(backend_url, str)
r = Redis(host='localhost', port=6379, db=3)
data = {
'function': 'log_prompt',
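log_to_db connects to Redis db 3 and packages the call as a dict tagged 'function': 'log_prompt', presumably so a background worker can replay it against the database instead of blocking the request thread. A sketch of that hand-off under that assumption; the queue name and serialization are illustrative:

    import pickle
    from redis import Redis

    def enqueue_db_call(data: dict):
        # Assumed pattern: a background worker pops these records and
        # performs the actual database write off the request path.
        r = Redis(host='localhost', port=6379, db=3)
        r.rpush('db_calls', pickle.dumps(data))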

View File

@@ -1,52 +0,0 @@
import json
from datetime import datetime, timedelta
import requests
from llm_server import opts
def get_power_states():
gpu_num = 0
output = {}
while True:
url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
try:
response = requests.get(url, timeout=10)
if response.status_code != 200:
break
data = json.loads(response.text)
power_state_data = data['data'][0]
power_state = None
for i in range(1, len(power_state_data)):
if power_state_data[i] == 1:
power_state = data['labels'][i]
break
output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
except Exception as e:
print('Failed to fetch Netdata metrics:', e)
return output
gpu_num += 1
return output
def get_gpu_wh(gpu_id: int):
chart_name = f"nvidia_smi.gpu{gpu_id}_power"
now = datetime.now()
one_hour_ago = now - timedelta(hours=1)
num_seconds = int((now - one_hour_ago).total_seconds())
params = {
"chart": chart_name,
"after": int(one_hour_ago.timestamp()),
"before": int(now.timestamp()),
"points": num_seconds,
"group": "second",
"format": "json",
"options": "absolute|jsonwrap"
}
response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
data = json.loads(response.text)
total_power_usage_watts = sum(point[1] for point in data['result']['data'])
# total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
return total_power_usage_kwh
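The deleted get_gpu_wh() requested one sample per second over the past hour, so summing the watt readings yields watt-seconds (joules); dividing by 3600 gives watt-hours, and by a further 1000 gives kWh. The same arithmetic, restated:

    def watt_samples_to_kwh(samples: list) -> float:
        # One reading per second, so the sum is watt-seconds (joules).
        watt_seconds = sum(samples)
        return round(watt_seconds / 1000 / 3600, 3)

    # A constant 250 W draw for one hour:
    # watt_samples_to_kwh([250.0] * 3600) == 0.25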

View File

@@ -43,24 +43,23 @@ def openai_chat_completions(model_name=None):
if not opts.enable_streaming:
return 'Streaming disabled', 403
handler.parameters, _ = handler.get_parameters()
handler.request_json_body = {
'messages': handler.request_json_body['messages'],
'model': handler.request_json_body['model'],
**handler.parameters
}
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
handler.parameters, e = handler.get_parameters()
handler.request_json_body = {
'messages': handler.request_json_body['messages'],
'model': handler.request_json_body['model'],
**handler.parameters
}
if opts.openai_silent_trim:
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
@@ -73,7 +72,7 @@ def openai_chat_completions(model_name=None):
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
else:
event = None
if not handler.is_client_ratelimited():
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
@@ -103,14 +102,16 @@ def openai_chat_completions(model_name=None):
stream_redis = Redis(db=8)
generated_text = ''
try:
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
if not stream_data:
print("No message received in 30 seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for r_timestamp, item in stream_data[0][1]:
timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data'])
if data['error']:
yield 'data: [DONE]\n\n'
@@ -154,6 +155,8 @@ def openai_chat_completions(model_name=None):
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event:
redis.lpush(f'notifications:{event.event_id}', 'canceled')
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
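The substantive fix here is the last_id cursor: reading from a hardcoded '0-0' re-delivered the whole stream on every XREAD, whereas advancing the cursor returns only unseen entries. The consumer loop extracted as a standalone sketch (payload fields follow the code above; the generator name is illustrative):

    import pickle
    from redis import Redis

    def read_new_tokens(stream_name: str):
        stream_redis = Redis(db=8)
        last_id = '0-0'  # start before the first entry
        while True:
            # Block up to 30s for entries newer than last_id.
            stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
            if not stream_data:
                return  # backend went quiet
            for entry_id, item in stream_data[0][1]:
                last_id = entry_id  # advance the cursor
                data = pickle.loads(item[b'data'])
                if data['error'] or data['completed']:
                    return
                yield data['new']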

View File

@@ -1,8 +1,10 @@
import pickle
import time
import traceback
import simplejson as json
from flask import Response, jsonify, request
from redis import Redis
from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp
@@ -12,7 +14,6 @@ from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
@@ -42,12 +43,14 @@ def openai_completions(model_name=None):
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
else:
# The handle_request() call below will load the prompt so we don't have
# to do anything else here.
pass
handler.request_json_body['prompt'] = handler.prompt
if not request_json_body.get('stream'):
invalid_oai_err_msg = validate_oai(request_json_body)
if invalid_oai_err_msg:
@@ -89,24 +92,36 @@ def openai_completions(model_name=None):
if not opts.enable_streaming:
return 'Streaming disabled', 403
event_id = None
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
handler.parameters, _ = handler.get_parameters()
handler.request_json_body = {
'prompt': handler.request_json_body['prompt'],
'model': handler.request_json_body['model'],
**handler.parameters
}
invalid_oai_err_msg = validate_oai(handler.request_json_body)
if invalid_oai_err_msg:
return invalid_oai_err_msg
if opts.openai_silent_trim:
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
if not handler.prompt:
# Prevent issues on the backend.
return 'Invalid prompt', 400
start_time = time.time()
request_valid, invalid_response = handler.validate_request()
if not request_valid:
return invalid_response
else:
handler.prompt = handler.request_json_body['prompt']
msg_to_backend = {
**handler.parameters,
'prompt': handler.prompt,
'stream': True,
}
event = None
if not handler.is_client_ratelimited():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
@@ -122,68 +137,53 @@ def openai_completions(model_name=None):
)
return handler.handle_ratelimited()
# Wait for permission to begin.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
return return_invalid_model_err(handler.request_json_body['model'])
time.sleep(0.1)
# Double check the model is still online
if not handler.check_online():
return return_invalid_model_err(handler.request_json_body['model'])
try:
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers)
r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
def generate():
try:
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
data = {
last_id = '0-0'
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
if not stream_data:
print("No message received in 30 seconds, closing stream.")
yield 'data: [DONE]\n\n'
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data'])
if data['error']:
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"created": timestamp,
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
"content": data['new']
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
@@ -196,11 +196,14 @@ def openai_completions(model_name=None):
r_url,
handler.backend_url,
)
return
except (Exception, GeneratorExit):
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
if event_id:
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
if event:
redis.lpush(f'notifications:{event.event_id}', 'canceled')
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')
except Exception:
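Each chunk the generator yields is framed as a server-sent event: a 'data: ' prefix, a JSON body, and a terminating blank line, with the literal 'data: [DONE]' sentinel closing the stream. A small framing helper matching the response objects built above (the helper name is illustrative):

    import json
    import time

    def sse_completion_chunk(oai_string: str, model: str, text: str) -> str:
        body = {
            'id': f'cmpl-{oai_string}',
            'object': 'text_completion',
            'created': int(time.time()),
            'model': model,
            'choices': [{'index': 0,
                         'delta': {'content': text},
                         'finish_reason': None}],
        }
        # SSE framing: 'data: ' prefix plus a blank line per event.
        return f'data: {json.dumps(body)}\n\n'

    SSE_DONE = 'data: [DONE]\n\n'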

View File

@@ -150,10 +150,6 @@ class OpenAIRequestHandler(RequestHandler):
"total_tokens": prompt_tokens + response_tokens
}
}), 200)
stats = redis.get('proxy_stats', dtype=dict)
if stats:
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
return response
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:

View File

@@ -37,6 +37,9 @@ class RequestHandler:
self.parameters = None
self.used = False
# This is null by default since most handlers need to transform the prompt in a specific way.
self.prompt = None
self.selected_model = selected_model
self.backend_url = get_a_cluster_backend(selected_model)
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)

View File

@@ -1,17 +1,18 @@
import json
import pickle
import time
import traceback
from flask import request
from redis import Redis
from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue
from ... import messages, opts
from ... import opts
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...sock import sock
@@ -35,6 +36,7 @@ def stream_with_model(ws, model_name=None):
def do_stream(ws, model_name):
event_id = None
try:
def send_err_and_quit(quitting_err_msg):
ws.send(json.dumps({
@@ -46,6 +48,7 @@ def do_stream(ws, model_name):
'event': 'stream_end',
'message_num': 1
}))
ws.close()
log_to_db(ip=handler.client_ip,
token=handler.token,
prompt=input_prompt,
@@ -55,7 +58,7 @@ def do_stream(ws, model_name):
headers=r_headers,
backend_response_code=response_status_code,
request_url=r_url,
backend_url=handler.cluster_backend_info,
backend_url=handler.backend_url,
response_tokens=None,
is_error=True
)
@@ -74,6 +77,7 @@ def do_stream(ws, model_name):
if not request_valid_json or not request_json_body.get('prompt'):
return 'Invalid JSON', 400
else:
# We have to do auth ourselves since the details are sent in the message.
auth_failure = require_api_key(request_json_body)
if auth_failure:
return auth_failure
@@ -89,14 +93,10 @@ def do_stream(ws, model_name):
}))
return
assert not handler.offline
if handler.cluster_backend_info['mode'] != 'vllm':
# TODO: implement other backends
raise NotImplementedError
event_id = None
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0
start_time = time.time()
@@ -113,119 +113,55 @@ def do_stream(ws, model_name):
send_err_and_quit(err_msg)
return
llm_request = {
**handler.parameters,
'prompt': input_prompt,
'stream': True,
handler.parameters, _ = handler.get_parameters()
handler.prompt = input_prompt
handler.request_json_body = {
'prompt': handler.prompt,
**handler.parameters
}
event = None
if not handler.is_client_ratelimited():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
if not event:
log_to_db(
handler.client_ip,
handler.token,
handler.request_json_body.get('prompt'),
None,
None,
handler.parameters,
request.headers,
response_status_code,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()
# Wait for permission to begin.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
return messages.BACKEND_OFFLINE, 404 # TODO: format this error
time.sleep(0.1)
# Double check the model is still online
if not handler.check_online():
return messages.BACKEND_OFFLINE, 404 # TODO: format this error
try:
response = generator(llm_request, handler.backend_url)
if not response:
error_msg = 'Failed to reach backend while streaming.'
print('Streaming failed:', error_msg)
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': msg
}))
else:
# Be extra careful when getting attributes from the response object
try:
response_status_code = response.status_code
except:
response_status_code = 0
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
try:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': new
}))
except:
# The client has closed the stream.
if response:
# Cancel the backend?
response.close()
# used to log here
r = handler.handle_ratelimited()
send_err_and_quit(r[0].data)
return
event_id = event.event_id
message_num += 1
partial_response = b'' # Reset the partial response
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
# If there is no more data, break the loop
if not chunk:
break
if response:
response.close()
# used to log here
except:
traceback.print_exc()
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
try:
last_id = '0-0' # The ID of the last entry we read.
while True:
stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
if not stream_data:
print("No message received in 30 seconds, closing stream.")
return
else:
for stream_index, item in stream_data[0][1]:
last_id = stream_index
data = pickle.loads(item[b'data'])
if data['error']:
print(data['error'])
send_err_and_quit('Encountered exception while streaming.')
return
elif data['new']:
ws.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': generated_text
'text': data['new']
}))
# used to log here
message_num += 1
generated_text = generated_text + data['new']
elif data['completed']:
return
except:
send_err_and_quit('Encountered exception while streaming.')
traceback.print_exc()
finally:
if event_id:
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
try:
ws.send(json.dumps({
'event': 'stream_end',
@@ -234,6 +170,7 @@ def do_stream(ws, model_name):
except:
# The client closed the stream.
pass
stream_redis.delete(stream_name)
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(ip=handler.client_ip,
@@ -248,6 +185,8 @@ def do_stream(ws, model_name):
backend_url=handler.backend_url
)
finally:
if event_id:
redis.lpush(f'notifications:{event_id}', 'canceled')
try:
# Must close the connection or greenlets will complain.
ws.close()
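On the websocket side the same stream entries are relayed as text_stream events with an incrementing message_num, followed by a single stream_end event. A condensed sketch of that relay under those conventions (ws is a flask-sock connection; chunks stands in for the Redis stream reader):

    import json

    def relay(ws, chunks):
        # chunks yields decoded text pieces from the Redis stream.
        message_num = 0
        for text in chunks:
            ws.send(json.dumps({
                'event': 'text_stream',
                'message_num': message_num,
                'text': text,
            }))
            message_num += 1
        ws.send(json.dumps({'event': 'stream_end',
                            'message_num': message_num}))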

View File

@@ -3,6 +3,6 @@ from flask_sock import Sock
sock = Sock()
def init_socketio(app):
def init_wssocket(app):
global sock
sock.init_app(app)

View File

@@ -7,7 +7,7 @@ from uuid import uuid4
from redis import Redis
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
@@ -20,15 +20,25 @@ def get_stream_name(name: str):
return f'{STREAM_NAME_PREFIX}:{name}'
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
prompt = msg_to_backend['prompt']
stream_name = get_stream_name(stream_name)
redis.delete(f'notifications:{event_id}')
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
try:
response = generator(msg_to_backend, backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
# If there is no more data, break the loop
if not chunk:
break
message = redis.lpop(f'notifications:{event_id}')
if message and message.decode('utf-8') == 'canceled':
print('Client canceled generation')
response.close()
return
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
@@ -74,14 +84,16 @@ def worker(backend_url):
try:
if do_stream:
# Return the name of the stream that the slave should connect to.
event = DataEvent(event_id)
event.set(get_stream_name(worker_id))
msg_to_backend = {
**parameters,
'prompt': request_json_body['prompt'],
'stream': True,
}
inference_do_stream(worker_id, msg_to_backend, backend_url)
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
else:
# Normal inference (not streaming).
success, response, error_msg = generator(request_json_body, backend_url)
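inference_do_stream() is the producer half: it parses the backend's null-delimited JSON chunks, XADDs each new piece of text to the per-request stream, and polls notifications:{event_id} so that a client-side lpush of 'canceled' stops generation mid-flight. A minimal sketch of publishing one chunk with that cancellation check (the helper name is illustrative; field names mirror the consumer side):

    import pickle
    from redis import Redis

    stream_redis = Redis(db=8)
    notif_redis = Redis()

    def publish_chunk(stream_name: str, event_id: str, new_text: str) -> bool:
        # Returns False when the client has asked us to stop.
        message = notif_redis.lpop(f'notifications:{event_id}')
        if message and message.decode('utf-8') == 'canceled':
            return False
        payload = {'new': new_text, 'completed': False, 'error': None}
        stream_redis.xadd(stream_name, {'data': pickle.dumps(payload)})
        return True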

View File

@@ -29,4 +29,4 @@ def console_printer():
# TODO: Active Workers and Processing should read the same. If not, that's an issue
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(10)
time.sleep(2)

View File

@@ -11,6 +11,7 @@ except ImportError:
HOST = 'localhost:5000'
URI = f'ws://{HOST}/api/v1/stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
@@ -82,5 +83,6 @@ async def print_response_stream(prompt):
if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)"
# prompt = "In order to make homemade bread, follow these steps:\n1)"
prompt = "Write a 300 word description of how an apple tree grows.\n\n"
asyncio.run(print_response_stream(prompt))
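The test script exercises the /api/v1/stream endpoint end to end. A trimmed sketch of the exchange it performs with the websockets library (event names match do_stream above; the exact request fields beyond 'prompt' are an assumption):

    import asyncio
    import json
    import websockets

    async def stream(prompt: str, uri: str = 'ws://localhost:5000/api/v1/stream'):
        async with websockets.connect(uri) as ws:
            # Assumed request shape: the prompt plus auth sent in the message.
            await ws.send(json.dumps({'prompt': prompt}))
            while True:
                msg = json.loads(await ws.recv())
                if msg['event'] == 'text_stream':
                    print(msg['text'], end='', flush=True)
                elif msg['event'] == 'stream_end':
                    break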

View File

@@ -28,7 +28,7 @@ from llm_server.routes.openai import openai_bp, openai_model_bp
from llm_server.routes.server_error import handle_server_error
from llm_server.routes.v1 import bp
from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.sock import init_socketio
from llm_server.sock import init_wssocket
# TODO: queue item timeout
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
@@ -68,10 +68,15 @@ except ModuleNotFoundError as e:
sys.exit(1)
app = Flask(__name__)
# Fixes ConcurrentObjectUseError
# https://github.com/miguelgrinberg/simple-websocket/issues/24
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
app.register_blueprint(bp, url_prefix='/api/')
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
init_socketio(app)
init_wssocket(app)
flask_cache.init_app(app)
flask_cache.clear()