get streaming working again
This commit is contained in:
parent
151b3e4769
commit
2c7773cc4f
|
@ -223,5 +223,14 @@ class RedisCustom(Redis):
|
|||
self.flush()
|
||||
return True
|
||||
|
||||
def lrange(self, name: str, start: int, end: int):
|
||||
return self.redis.lrange(self._key(name), start, end)
|
||||
|
||||
def delete(self, *names: KeyT):
|
||||
return self.redis.delete(*[self._key(i) for i in names])
|
||||
|
||||
def lpop(self, name: str, count: Optional[int] = None):
|
||||
return self.redis.lpop(self._key(name), count)
|
||||
|
||||
|
||||
redis = RedisCustom('local_llm')
|
||||
|
|
|
@ -5,6 +5,9 @@ from redis import Redis
|
|||
|
||||
|
||||
def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
|
||||
assert isinstance(prompt, str)
|
||||
assert isinstance(backend_url, str)
|
||||
|
||||
r = Redis(host='localhost', port=6379, db=3)
|
||||
data = {
|
||||
'function': 'log_prompt',
|
||||
|
|
|
@ -1,52 +0,0 @@
|
|||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import requests
|
||||
|
||||
from llm_server import opts
|
||||
|
||||
|
||||
def get_power_states():
|
||||
gpu_num = 0
|
||||
output = {}
|
||||
while True:
|
||||
url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
|
||||
try:
|
||||
response = requests.get(url, timeout=10)
|
||||
if response.status_code != 200:
|
||||
break
|
||||
data = json.loads(response.text)
|
||||
power_state_data = data['data'][0]
|
||||
power_state = None
|
||||
for i in range(1, len(power_state_data)):
|
||||
if power_state_data[i] == 1:
|
||||
power_state = data['labels'][i]
|
||||
break
|
||||
output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
|
||||
except Exception as e:
|
||||
print('Failed to fetch Netdata metrics:', e)
|
||||
return output
|
||||
gpu_num += 1
|
||||
return output
|
||||
|
||||
|
||||
def get_gpu_wh(gpu_id: int):
|
||||
chart_name = f"nvidia_smi.gpu{gpu_id}_power"
|
||||
now = datetime.now()
|
||||
one_hour_ago = now - timedelta(hours=1)
|
||||
num_seconds = int((now - one_hour_ago).total_seconds())
|
||||
params = {
|
||||
"chart": chart_name,
|
||||
"after": int(one_hour_ago.timestamp()),
|
||||
"before": int(now.timestamp()),
|
||||
"points": num_seconds,
|
||||
"group": "second",
|
||||
"format": "json",
|
||||
"options": "absolute|jsonwrap"
|
||||
}
|
||||
response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
|
||||
data = json.loads(response.text)
|
||||
total_power_usage_watts = sum(point[1] for point in data['result']['data'])
|
||||
# total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
|
||||
total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
|
||||
return total_power_usage_kwh
|
|
@ -43,24 +43,23 @@ def openai_chat_completions(model_name=None):
|
|||
if not opts.enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
handler.parameters, _ = handler.get_parameters()
|
||||
handler.request_json_body = {
|
||||
'messages': handler.request_json_body['messages'],
|
||||
'model': handler.request_json_body['model'],
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
handler.parameters, e = handler.get_parameters()
|
||||
handler.request_json_body = {
|
||||
'messages': handler.request_json_body['messages'],
|
||||
'model': handler.request_json_body['model'],
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
|
||||
else:
|
||||
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
|
||||
|
||||
if not handler.prompt:
|
||||
# Prevent issues on the backend.
|
||||
return 'Invalid prompt', 400
|
||||
|
@ -73,90 +72,94 @@ def openai_chat_completions(model_name=None):
|
|||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
else:
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
429,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
429,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
def generate():
|
||||
stream_name = event.wait()
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
try:
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
|
||||
if not stream_data:
|
||||
print("No message received in 30 seconds, closing stream.")
|
||||
yield 'data: [DONE]\n\n'
|
||||
else:
|
||||
for r_timestamp, item in stream_data[0][1]:
|
||||
timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
|
||||
data = pickle.loads(item[b'data'])
|
||||
if data['error']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
elif data['new']:
|
||||
response = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": timestamp,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": data['new']
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
generated_text = generated_text + data['new']
|
||||
yield f'data: {json.dumps(response)}\n\n'
|
||||
elif data['completed']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
200,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return
|
||||
except (Exception, GeneratorExit):
|
||||
traceback.print_exc()
|
||||
yield 'data: [DONE]\n\n'
|
||||
finally:
|
||||
stream_redis.delete(stream_name)
|
||||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
def generate():
|
||||
stream_name = event.wait()
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
try:
|
||||
last_id = '0-0'
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
|
||||
if not stream_data:
|
||||
print("No message received in 30 seconds, closing stream.")
|
||||
yield 'data: [DONE]\n\n'
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
timestamp = int(stream_index.decode('utf-8').split('-')[0])
|
||||
data = pickle.loads(item[b'data'])
|
||||
if data['error']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
elif data['new']:
|
||||
response = {
|
||||
"id": f"chatcmpl-{oai_string}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": timestamp,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": data['new']
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
generated_text = generated_text + data['new']
|
||||
yield f'data: {json.dumps(response)}\n\n'
|
||||
elif data['completed']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
200,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return
|
||||
except (Exception, GeneratorExit):
|
||||
traceback.print_exc()
|
||||
yield 'data: [DONE]\n\n'
|
||||
finally:
|
||||
if event:
|
||||
redis.lpush(f'notifications:{event.event_id}', 'canceled')
|
||||
stream_redis.delete(stream_name)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
import pickle
|
||||
import time
|
||||
import traceback
|
||||
|
||||
import simplejson as json
|
||||
from flask import Response, jsonify, request
|
||||
from redis import Redis
|
||||
|
||||
from llm_server.custom_redis import redis
|
||||
from . import openai_bp, openai_model_bp
|
||||
|
@ -12,7 +14,6 @@ from ..queue import priority_queue
|
|||
from ... import opts
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm import get_token_count
|
||||
from ...llm.generator import generator
|
||||
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
|
||||
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
|
||||
|
||||
|
@ -42,12 +43,14 @@ def openai_completions(model_name=None):
|
|||
handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
|
||||
else:
|
||||
# The handle_request() call below will load the prompt so we don't have
|
||||
# to do anything else here.
|
||||
pass
|
||||
|
||||
handler.request_json_body['prompt'] = handler.prompt
|
||||
|
||||
if not request_json_body.get('stream'):
|
||||
invalid_oai_err_msg = validate_oai(request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
|
@ -89,120 +92,120 @@ def openai_completions(model_name=None):
|
|||
if not opts.enable_streaming:
|
||||
return 'Streaming disabled', 403
|
||||
|
||||
event_id = None
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
|
||||
handler.parameters, _ = handler.get_parameters()
|
||||
handler.request_json_body = {
|
||||
'prompt': handler.request_json_body['prompt'],
|
||||
'model': handler.request_json_body['model'],
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
invalid_oai_err_msg = validate_oai(handler.request_json_body)
|
||||
if invalid_oai_err_msg:
|
||||
return invalid_oai_err_msg
|
||||
|
||||
if opts.openai_silent_trim:
|
||||
handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
|
||||
if not handler.prompt:
|
||||
# Prevent issues on the backend.
|
||||
return 'Invalid prompt', 400
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
request_valid, invalid_response = handler.validate_request()
|
||||
if not request_valid:
|
||||
return invalid_response
|
||||
else:
|
||||
handler.prompt = handler.request_json_body['prompt']
|
||||
msg_to_backend = {
|
||||
**handler.parameters,
|
||||
'prompt': handler.prompt,
|
||||
'stream': True,
|
||||
}
|
||||
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
429,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
429,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
# Wait for permission to begin.
|
||||
event_id = event.event_id
|
||||
pubsub = redis.pubsub()
|
||||
pubsub.subscribe(event_id)
|
||||
for item in pubsub.listen():
|
||||
if item['type'] == 'message':
|
||||
msg = item['data'].decode('utf-8')
|
||||
if msg == 'begin':
|
||||
break
|
||||
elif msg == 'offline':
|
||||
return return_invalid_model_err(handler.request_json_body['model'])
|
||||
time.sleep(0.1)
|
||||
try:
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
# Double check the model is still online
|
||||
if not handler.check_online():
|
||||
return return_invalid_model_err(handler.request_json_body['model'])
|
||||
|
||||
try:
|
||||
response = generator(msg_to_backend, handler.backend_url)
|
||||
r_headers = dict(request.headers)
|
||||
r_url = request.url
|
||||
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
|
||||
oai_string = generate_oai_string(30)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
|
||||
data = {
|
||||
"id": f"cmpl-{oai_string}",
|
||||
"object": "text_completion",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": new
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
yield f'data: {json.dumps(data)}\n\n'
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
200,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
finally:
|
||||
if event_id:
|
||||
redis.publish(event_id, 'finished')
|
||||
def generate():
|
||||
stream_name = event.wait()
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
try:
|
||||
last_id = '0-0'
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
|
||||
if not stream_data:
|
||||
print("No message received in 30 seconds, closing stream.")
|
||||
yield 'data: [DONE]\n\n'
|
||||
else:
|
||||
print('event_id was None!')
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
timestamp = int(stream_index.decode('utf-8').split('-')[0])
|
||||
data = pickle.loads(item[b'data'])
|
||||
if data['error']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
return
|
||||
elif data['new']:
|
||||
response = {
|
||||
"id": f"cmpl-{oai_string}",
|
||||
"object": "text_completion",
|
||||
"created": timestamp,
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {
|
||||
"content": data['new']
|
||||
},
|
||||
"finish_reason": None
|
||||
}
|
||||
]
|
||||
}
|
||||
generated_text = generated_text + data['new']
|
||||
yield f'data: {json.dumps(response)}\n\n'
|
||||
elif data['completed']:
|
||||
yield 'data: [DONE]\n\n'
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.prompt,
|
||||
generated_text,
|
||||
elapsed_time,
|
||||
handler.parameters,
|
||||
r_headers,
|
||||
200,
|
||||
r_url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return
|
||||
except (Exception, GeneratorExit):
|
||||
traceback.print_exc()
|
||||
yield 'data: [DONE]\n\n'
|
||||
finally:
|
||||
if event:
|
||||
redis.lpush(f'notifications:{event.event_id}', 'canceled')
|
||||
stream_redis.delete(stream_name)
|
||||
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
return Response(generate(), mimetype='text/event-stream')
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
return 'INTERNAL SERVER', 500
|
||||
|
|
|
@ -150,10 +150,6 @@ class OpenAIRequestHandler(RequestHandler):
|
|||
"total_tokens": prompt_tokens + response_tokens
|
||||
}
|
||||
}), 200)
|
||||
|
||||
stats = redis.get('proxy_stats', dtype=dict)
|
||||
if stats:
|
||||
response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
|
||||
return response
|
||||
|
||||
def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
|
||||
|
|
|
@ -37,6 +37,9 @@ class RequestHandler:
|
|||
self.parameters = None
|
||||
self.used = False
|
||||
|
||||
# This is null by default since most handlers need to transform the prompt in a specific way.
|
||||
self.prompt = None
|
||||
|
||||
self.selected_model = selected_model
|
||||
self.backend_url = get_a_cluster_backend(selected_model)
|
||||
self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
|
||||
|
|
|
@ -1,17 +1,18 @@
|
|||
import json
|
||||
import pickle
|
||||
import time
|
||||
import traceback
|
||||
|
||||
from flask import request
|
||||
from redis import Redis
|
||||
|
||||
from . import bp
|
||||
from ..helpers.http import require_api_key, validate_json
|
||||
from ..ooba_request_handler import OobaRequestHandler
|
||||
from ..queue import priority_queue
|
||||
from ... import messages, opts
|
||||
from ... import opts
|
||||
from ...custom_redis import redis
|
||||
from ...database.log_to_db import log_to_db
|
||||
from ...llm.generator import generator
|
||||
from ...sock import sock
|
||||
|
||||
|
||||
|
@ -35,6 +36,7 @@ def stream_with_model(ws, model_name=None):
|
|||
|
||||
|
||||
def do_stream(ws, model_name):
|
||||
event_id = None
|
||||
try:
|
||||
def send_err_and_quit(quitting_err_msg):
|
||||
ws.send(json.dumps({
|
||||
|
@ -46,6 +48,7 @@ def do_stream(ws, model_name):
|
|||
'event': 'stream_end',
|
||||
'message_num': 1
|
||||
}))
|
||||
ws.close()
|
||||
log_to_db(ip=handler.client_ip,
|
||||
token=handler.token,
|
||||
prompt=input_prompt,
|
||||
|
@ -55,7 +58,7 @@ def do_stream(ws, model_name):
|
|||
headers=r_headers,
|
||||
backend_response_code=response_status_code,
|
||||
request_url=r_url,
|
||||
backend_url=handler.cluster_backend_info,
|
||||
backend_url=handler.backend_url,
|
||||
response_tokens=None,
|
||||
is_error=True
|
||||
)
|
||||
|
@ -74,6 +77,7 @@ def do_stream(ws, model_name):
|
|||
if not request_valid_json or not request_json_body.get('prompt'):
|
||||
return 'Invalid JSON', 400
|
||||
else:
|
||||
# We have to do auth ourselves since the details are sent in the message.
|
||||
auth_failure = require_api_key(request_json_body)
|
||||
if auth_failure:
|
||||
return auth_failure
|
||||
|
@ -89,14 +93,10 @@ def do_stream(ws, model_name):
|
|||
}))
|
||||
return
|
||||
|
||||
assert not handler.offline
|
||||
|
||||
if handler.cluster_backend_info['mode'] != 'vllm':
|
||||
# TODO: implement other backends
|
||||
raise NotImplementedError
|
||||
|
||||
event_id = None
|
||||
generated_text = ''
|
||||
input_prompt = request_json_body['prompt']
|
||||
response_status_code = 0
|
||||
start_time = time.time()
|
||||
|
@ -113,119 +113,55 @@ def do_stream(ws, model_name):
|
|||
send_err_and_quit(err_msg)
|
||||
return
|
||||
|
||||
llm_request = {
|
||||
**handler.parameters,
|
||||
'prompt': input_prompt,
|
||||
'stream': True,
|
||||
handler.parameters, _ = handler.get_parameters()
|
||||
handler.prompt = input_prompt
|
||||
handler.request_json_body = {
|
||||
'prompt': handler.prompt,
|
||||
**handler.parameters
|
||||
}
|
||||
|
||||
event = None
|
||||
if not handler.is_client_ratelimited():
|
||||
# Add a dummy event to the queue and wait for it to reach a worker
|
||||
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
|
||||
event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
|
||||
if not event:
|
||||
log_to_db(
|
||||
handler.client_ip,
|
||||
handler.token,
|
||||
handler.request_json_body.get('prompt'),
|
||||
None,
|
||||
None,
|
||||
handler.parameters,
|
||||
request.headers,
|
||||
response_status_code,
|
||||
request.url,
|
||||
handler.backend_url,
|
||||
)
|
||||
return handler.handle_ratelimited()
|
||||
|
||||
# Wait for permission to begin.
|
||||
r = handler.handle_ratelimited()
|
||||
send_err_and_quit(r[0].data)
|
||||
return
|
||||
event_id = event.event_id
|
||||
pubsub = redis.pubsub()
|
||||
pubsub.subscribe(event_id)
|
||||
for item in pubsub.listen():
|
||||
if item['type'] == 'message':
|
||||
msg = item['data'].decode('utf-8')
|
||||
if msg == 'begin':
|
||||
break
|
||||
elif msg == 'offline':
|
||||
return messages.BACKEND_OFFLINE, 404 # TODO: format this error
|
||||
time.sleep(0.1)
|
||||
|
||||
# Double check the model is still online
|
||||
if not handler.check_online():
|
||||
return messages.BACKEND_OFFLINE, 404 # TODO: format this error
|
||||
stream_name = event.wait()
|
||||
stream_redis = Redis(db=8)
|
||||
generated_text = ''
|
||||
|
||||
try:
|
||||
response = generator(llm_request, handler.backend_url)
|
||||
if not response:
|
||||
error_msg = 'Failed to reach backend while streaming.'
|
||||
print('Streaming failed:', error_msg)
|
||||
msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': msg
|
||||
}))
|
||||
else:
|
||||
# Be extra careful when getting attributes from the response object
|
||||
try:
|
||||
response_status_code = response.status_code
|
||||
except:
|
||||
response_status_code = 0
|
||||
|
||||
partial_response = b''
|
||||
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
for json_str in json_strs:
|
||||
if json_str:
|
||||
try:
|
||||
json_obj = json.loads(json_str.decode())
|
||||
new = json_obj['text'][0].split(input_prompt + generated_text)[1]
|
||||
generated_text = generated_text + new
|
||||
except IndexError:
|
||||
# ????
|
||||
continue
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': new
|
||||
}))
|
||||
except:
|
||||
# The has client closed the stream.
|
||||
if response:
|
||||
# Cancel the backend?
|
||||
response.close()
|
||||
# used to log here
|
||||
return
|
||||
|
||||
message_num += 1
|
||||
partial_response = b'' # Reset the partial response
|
||||
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
if response:
|
||||
response.close()
|
||||
# used to log here
|
||||
last_id = '0-0' # The ID of the last entry we read.
|
||||
while True:
|
||||
stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
|
||||
if not stream_data:
|
||||
print("No message received in 30 seconds, closing stream.")
|
||||
return
|
||||
else:
|
||||
for stream_index, item in stream_data[0][1]:
|
||||
last_id = stream_index
|
||||
data = pickle.loads(item[b'data'])
|
||||
if data['error']:
|
||||
print(data['error'])
|
||||
send_err_and_quit('Encountered exception while streaming.')
|
||||
return
|
||||
elif data['new']:
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': data['new']
|
||||
}))
|
||||
message_num += 1
|
||||
generated_text = generated_text + data['new']
|
||||
elif data['completed']:
|
||||
return
|
||||
except:
|
||||
send_err_and_quit('Encountered exception while streaming.')
|
||||
traceback.print_exc()
|
||||
generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
|
||||
ws.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': generated_text
|
||||
}))
|
||||
# used to log here
|
||||
finally:
|
||||
if event_id:
|
||||
redis.publish(event_id, 'finished')
|
||||
else:
|
||||
print('event_id was None!')
|
||||
|
||||
try:
|
||||
ws.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
|
@ -234,6 +170,7 @@ def do_stream(ws, model_name):
|
|||
except:
|
||||
# The client closed the stream.
|
||||
pass
|
||||
stream_redis.delete(stream_name)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
log_to_db(ip=handler.client_ip,
|
||||
|
@ -248,6 +185,8 @@ def do_stream(ws, model_name):
|
|||
backend_url=handler.backend_url
|
||||
)
|
||||
finally:
|
||||
if event_id:
|
||||
redis.lpush(f'notifications:{event_id}', 'canceled')
|
||||
try:
|
||||
# Must close the connection or greenlets will complain.
|
||||
ws.close()
|
||||
|
|
|
@ -3,6 +3,6 @@ from flask_sock import Sock
|
|||
sock = Sock()
|
||||
|
||||
|
||||
def init_socketio(app):
|
||||
def init_wssocket(app):
|
||||
global sock
|
||||
sock.init_app(app)
|
||||
|
|
|
@ -7,7 +7,7 @@ from uuid import uuid4
|
|||
from redis import Redis
|
||||
|
||||
from llm_server.cluster.cluster_config import cluster_config
|
||||
from llm_server.custom_redis import RedisCustom
|
||||
from llm_server.custom_redis import RedisCustom, redis
|
||||
from llm_server.llm.generator import generator
|
||||
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
|
||||
|
||||
|
@ -20,15 +20,25 @@ def get_stream_name(name: str):
|
|||
return f'{STREAM_NAME_PREFIX}:{name}'
|
||||
|
||||
|
||||
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
|
||||
def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
|
||||
prompt = msg_to_backend['prompt']
|
||||
stream_name = get_stream_name(stream_name)
|
||||
redis.delete(f'notifications:{event_id}')
|
||||
stream_redis.delete(get_stream_name(stream_name)) # be extra sure
|
||||
try:
|
||||
response = generator(msg_to_backend, backend_url)
|
||||
generated_text = ''
|
||||
partial_response = b''
|
||||
for chunk in response.iter_content(chunk_size=1):
|
||||
# If there is no more data, break the loop
|
||||
if not chunk:
|
||||
break
|
||||
message = redis.lpop(f'notifications:{event_id}')
|
||||
if message and message.decode('utf-8') == 'canceled':
|
||||
print('Client canceled generation')
|
||||
response.close()
|
||||
return
|
||||
|
||||
partial_response += chunk
|
||||
if partial_response.endswith(b'\x00'):
|
||||
json_strs = partial_response.split(b'\x00')
|
||||
|
@ -74,14 +84,16 @@ def worker(backend_url):
|
|||
|
||||
try:
|
||||
if do_stream:
|
||||
# Return the name of the stream that the slave should connect to.
|
||||
event = DataEvent(event_id)
|
||||
event.set(get_stream_name(worker_id))
|
||||
|
||||
msg_to_backend = {
|
||||
**parameters,
|
||||
'prompt': request_json_body['prompt'],
|
||||
'stream': True,
|
||||
}
|
||||
inference_do_stream(worker_id, msg_to_backend, backend_url)
|
||||
inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
|
||||
else:
|
||||
# Normal inference (not streaming).
|
||||
success, response, error_msg = generator(request_json_body, backend_url)
|
||||
|
|
|
@ -29,4 +29,4 @@ def console_printer():
|
|||
# TODO: Active Workers and Processing should read the same. If not, that's an issue
|
||||
|
||||
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
|
||||
time.sleep(10)
|
||||
time.sleep(2)
|
||||
|
|
|
@ -11,6 +11,7 @@ except ImportError:
|
|||
HOST = 'localhost:5000'
|
||||
URI = f'ws://{HOST}/api/v1/stream'
|
||||
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
|
@ -82,5 +83,6 @@ async def print_response_stream(prompt):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
# prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
prompt = "Write a 300 word description of how an apple tree grows.\n\n"
|
||||
asyncio.run(print_response_stream(prompt))
|
||||
|
|
|
@ -28,7 +28,7 @@ from llm_server.routes.openai import openai_bp, openai_model_bp
|
|||
from llm_server.routes.server_error import handle_server_error
|
||||
from llm_server.routes.v1 import bp
|
||||
from llm_server.routes.v1.generate_stats import generate_stats
|
||||
from llm_server.sock import init_socketio
|
||||
from llm_server.sock import init_wssocket
|
||||
|
||||
# TODO: queue item timeout
|
||||
# TODO: return an `error: True`, error code, and error message rather than just a formatted message
|
||||
|
@ -68,10 +68,15 @@ except ModuleNotFoundError as e:
|
|||
sys.exit(1)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Fixes ConcurrentObjectUseError
|
||||
# https://github.com/miguelgrinberg/simple-websocket/issues/24
|
||||
app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
|
||||
|
||||
app.register_blueprint(bp, url_prefix='/api/')
|
||||
app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
|
||||
app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
|
||||
init_socketio(app)
|
||||
init_wssocket(app)
|
||||
flask_cache.init_app(app)
|
||||
flask_cache.clear()
|
||||
|
||||
|
|
Reference in New Issue