get streaming working again

parent 151b3e4769
commit 2c7773cc4f
@@ -223,5 +223,14 @@ class RedisCustom(Redis):
         self.flush()
         return True

+    def lrange(self, name: str, start: int, end: int):
+        return self.redis.lrange(self._key(name), start, end)
+
+    def delete(self, *names: KeyT):
+        return self.redis.delete(*[self._key(i) for i in names])
+
+    def lpop(self, name: str, count: Optional[int] = None):
+        return self.redis.lpop(self._key(name), count)
+

 redis = RedisCustom('local_llm')
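The three wrappers added above all route the caller's key through RedisCustom._key() so every entry lands under the prefix chosen at construction time ('local_llm' here). A minimal sketch of that namespacing pattern, assuming a default localhost Redis and a hypothetical PrefixedRedis class rather than the real one:

    from typing import Optional

    from redis import Redis


    class PrefixedRedis:
        # Illustration only: prefix every key before touching Redis, like RedisCustom._key().
        def __init__(self, prefix: str):
            self.prefix = prefix
            self.redis = Redis()  # assumed local default connection

        def _key(self, name: str) -> str:
            return f'{self.prefix}:{name}'

        def lpop(self, name: str, count: Optional[int] = None):
            # Mirrors the wrapper shape used above.
            return self.redis.lpop(self._key(name), count)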
@@ -5,6 +5,9 @@ from redis import Redis


 def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+    assert isinstance(prompt, str)
+    assert isinstance(backend_url, str)
+
     r = Redis(host='localhost', port=6379, db=3)
     data = {
         'function': 'log_prompt',
@@ -1,52 +0,0 @@
-import json
-from datetime import datetime, timedelta
-
-import requests
-
-from llm_server import opts
-
-
-def get_power_states():
-    gpu_num = 0
-    output = {}
-    while True:
-        url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
-        try:
-            response = requests.get(url, timeout=10)
-            if response.status_code != 200:
-                break
-            data = json.loads(response.text)
-            power_state_data = data['data'][0]
-            power_state = None
-            for i in range(1, len(power_state_data)):
-                if power_state_data[i] == 1:
-                    power_state = data['labels'][i]
-                    break
-            output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
-        except Exception as e:
-            print('Failed to fetch Netdata metrics:', e)
-            return output
-        gpu_num += 1
-    return output
-
-
-def get_gpu_wh(gpu_id: int):
-    chart_name = f"nvidia_smi.gpu{gpu_id}_power"
-    now = datetime.now()
-    one_hour_ago = now - timedelta(hours=1)
-    num_seconds = int((now - one_hour_ago).total_seconds())
-    params = {
-        "chart": chart_name,
-        "after": int(one_hour_ago.timestamp()),
-        "before": int(now.timestamp()),
-        "points": num_seconds,
-        "group": "second",
-        "format": "json",
-        "options": "absolute|jsonwrap"
-    }
-    response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
-    data = json.loads(response.text)
-    total_power_usage_watts = sum(point[1] for point in data['result']['data'])
-    # total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
-    total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
-    return total_power_usage_kwh
@@ -43,24 +43,23 @@ def openai_chat_completions(model_name=None):
     if not opts.enable_streaming:
         return 'Streaming disabled', 403
-
-    handler.parameters, _ = handler.get_parameters()
-    handler.request_json_body = {
-        'messages': handler.request_json_body['messages'],
-        'model': handler.request_json_body['model'],
-        **handler.parameters
-    }
-
     invalid_oai_err_msg = validate_oai(handler.request_json_body)
     if invalid_oai_err_msg:
         return invalid_oai_err_msg

     handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])

+    handler.parameters, e = handler.get_parameters()
+    handler.request_json_body = {
+        'messages': handler.request_json_body['messages'],
+        'model': handler.request_json_body['model'],
+        **handler.parameters
+    }
+
     if opts.openai_silent_trim:
         handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
     else:
         handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])

     if not handler.prompt:
         # Prevent issues on the backend.
         return 'Invalid prompt', 400
@@ -73,7 +72,7 @@ def openai_chat_completions(model_name=None):
     request_valid, invalid_response = handler.validate_request()
     if not request_valid:
         return invalid_response
-    else:
-        event = None
-        if not handler.is_client_ratelimited():
-            event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
+    event = None
+    if not handler.is_client_ratelimited():
+        event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
@@ -103,14 +102,16 @@ def openai_chat_completions(model_name=None):
         stream_redis = Redis(db=8)
         generated_text = ''
         try:
+            last_id = '0-0'
             while True:
-                stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
+                stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
                 if not stream_data:
                     print("No message received in 30 seconds, closing stream.")
                     yield 'data: [DONE]\n\n'
                 else:
-                    for r_timestamp, item in stream_data[0][1]:
-                        timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
+                    for stream_index, item in stream_data[0][1]:
+                        last_id = stream_index
+                        timestamp = int(stream_index.decode('utf-8').split('-')[0])
                         data = pickle.loads(item[b'data'])
                         if data['error']:
                             yield 'data: [DONE]\n\n'
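This hunk is the core of the fix: instead of passing the fixed ID '0-0' to XREAD on every loop iteration (which re-reads the whole stream each time), the reader keeps a last_id cursor and advances it as entries arrive, so each blocking read only returns new entries. A minimal, self-contained sketch of that consumer pattern with redis-py; the stream name, db number, and the consume_stream helper are illustrative assumptions:

    import pickle

    from redis import Redis


    def consume_stream(stream_name: str, db: int = 8):
        # Yield deserialized entries from a Redis stream, blocking up to 30s per read.
        stream_redis = Redis(db=db)
        last_id = '0-0'  # start before the first entry
        while True:
            stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
            if not stream_data:
                return  # nothing new within the timeout; give up
            for entry_id, item in stream_data[0][1]:
                last_id = entry_id  # advance the cursor so entries are not re-read
                yield pickle.loads(item[b'data'])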
@@ -154,6 +155,8 @@ def openai_chat_completions(model_name=None):
             traceback.print_exc()
             yield 'data: [DONE]\n\n'
         finally:
+            if event:
+                redis.lpush(f'notifications:{event.event_id}', 'canceled')
             stream_redis.delete(stream_name)

     return Response(generate(), mimetype='text/event-stream')
@@ -1,8 +1,10 @@
+import pickle
 import time
 import traceback

 import simplejson as json
 from flask import Response, jsonify, request
+from redis import Redis

 from llm_server.custom_redis import redis
 from . import openai_bp, openai_model_bp
@@ -12,7 +14,6 @@ from ..queue import priority_queue
 from ... import opts
 from ...database.log_to_db import log_to_db
 from ...llm import get_token_count
-from ...llm.generator import generator
 from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
 from ...llm.openai.transform import generate_oai_string, trim_string_to_fit

@@ -42,12 +43,14 @@ def openai_completions(model_name=None):
     handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])

     if opts.openai_silent_trim:
-        handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+        handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
     else:
         # The handle_request() call below will load the prompt so we don't have
         # to do anything else here.
         pass

+    handler.request_json_body['prompt'] = handler.prompt
+
     if not request_json_body.get('stream'):
         invalid_oai_err_msg = validate_oai(request_json_body)
         if invalid_oai_err_msg:
@@ -89,24 +92,36 @@ def openai_completions(model_name=None):
         if not opts.enable_streaming:
             return 'Streaming disabled', 403

-        event_id = None
+        request_valid, invalid_response = handler.validate_request()
+        if not request_valid:
+            return invalid_response
+
+        handler.parameters, _ = handler.get_parameters()
+        handler.request_json_body = {
+            'prompt': handler.request_json_body['prompt'],
+            'model': handler.request_json_body['model'],
+            **handler.parameters
+        }
+
+        invalid_oai_err_msg = validate_oai(handler.request_json_body)
+        if invalid_oai_err_msg:
+            return invalid_oai_err_msg
+
+        if opts.openai_silent_trim:
+            handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
+        if not handler.prompt:
+            # Prevent issues on the backend.
+            return 'Invalid prompt', 400
+
         start_time = time.time()

         request_valid, invalid_response = handler.validate_request()
         if not request_valid:
             return invalid_response
-        else:
-            handler.prompt = handler.request_json_body['prompt']
-            msg_to_backend = {
-                **handler.parameters,
-                'prompt': handler.prompt,
-                'stream': True,
-            }

         event = None
         if not handler.is_client_ratelimited():
-            # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
         if not event:
             log_to_db(
                 handler.client_ip,
@@ -122,68 +137,53 @@ def openai_completions(model_name=None):
             )
             return handler.handle_ratelimited()

-        # Wait for permission to begin.
-        event_id = event.event_id
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
-        for item in pubsub.listen():
-            if item['type'] == 'message':
-                msg = item['data'].decode('utf-8')
-                if msg == 'begin':
-                    break
-                elif msg == 'offline':
-                    return return_invalid_model_err(handler.request_json_body['model'])
-            time.sleep(0.1)
-
-        # Double check the model is still online
-        if not handler.check_online():
-            return return_invalid_model_err(handler.request_json_body['model'])
-
         try:
-            response = generator(msg_to_backend, handler.backend_url)
             r_headers = dict(request.headers)
             r_url = request.url
             model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
             oai_string = generate_oai_string(30)

             def generate():
-                try:
-                    generated_text = ''
-                    partial_response = b''
-                    for chunk in response.iter_content(chunk_size=1):
-                        partial_response += chunk
-                        if partial_response.endswith(b'\x00'):
-                            json_strs = partial_response.split(b'\x00')
-                            for json_str in json_strs:
-                                if json_str:
-                                    try:
-                                        json_obj = json.loads(json_str.decode())
-                                        new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
-                                        generated_text = generated_text + new
-                                    except IndexError:
-                                        # ????
-                                        continue
-                                    data = {
+                stream_name = event.wait()
+                stream_redis = Redis(db=8)
+                generated_text = ''
+                try:
+                    last_id = '0-0'
+                    while True:
+                        stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
+                        if not stream_data:
+                            print("No message received in 30 seconds, closing stream.")
+                            yield 'data: [DONE]\n\n'
+                        else:
+                            for stream_index, item in stream_data[0][1]:
+                                last_id = stream_index
+                                timestamp = int(stream_index.decode('utf-8').split('-')[0])
+                                data = pickle.loads(item[b'data'])
+                                if data['error']:
+                                    yield 'data: [DONE]\n\n'
+                                    return
+                                elif data['new']:
+                                    response = {
                                         "id": f"cmpl-{oai_string}",
                                         "object": "text_completion",
-                                        "created": int(time.time()),
+                                        "created": timestamp,
                                         "model": model,
                                         "choices": [
                                             {
                                                 "index": 0,
                                                 "delta": {
-                                                    "content": new
+                                                    "content": data['new']
                                                 },
                                                 "finish_reason": None
                                             }
                                         ]
                                     }
-                                    yield f'data: {json.dumps(data)}\n\n'
+                                    generated_text = generated_text + data['new']
+                                    yield f'data: {json.dumps(response)}\n\n'
+                                elif data['completed']:
                                     yield 'data: [DONE]\n\n'
                                     end_time = time.time()
                                     elapsed_time = end_time - start_time

                                     log_to_db(
                                         handler.client_ip,
                                         handler.token,
@@ -196,11 +196,14 @@ def openai_completions(model_name=None):
                                         r_url,
                                         handler.backend_url,
                                     )
+                                    return
+                except (Exception, GeneratorExit):
+                    traceback.print_exc()
+                    yield 'data: [DONE]\n\n'
                 finally:
-                    if event_id:
-                        redis.publish(event_id, 'finished')
-                    else:
-                        print('event_id was None!')
+                    if event:
+                        redis.lpush(f'notifications:{event.event_id}', 'canceled')
+                    stream_redis.delete(stream_name)

             return Response(generate(), mimetype='text/event-stream')
         except Exception:
@@ -150,10 +150,6 @@ class OpenAIRequestHandler(RequestHandler):
                 "total_tokens": prompt_tokens + response_tokens
             }
         }), 200)
-
-        stats = redis.get('proxy_stats', dtype=dict)
-        if stats:
-            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
         return response

     def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
@@ -37,6 +37,9 @@ class RequestHandler:
         self.parameters = None
         self.used = False

+        # This is null by default since most handlers need to transform the prompt in a specific way.
+        self.prompt = None
+
         self.selected_model = selected_model
         self.backend_url = get_a_cluster_backend(selected_model)
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)
@@ -1,17 +1,18 @@
 import json
+import pickle
 import time
 import traceback

 from flask import request
+from redis import Redis

 from . import bp
 from ..helpers.http import require_api_key, validate_json
 from ..ooba_request_handler import OobaRequestHandler
 from ..queue import priority_queue
-from ... import messages, opts
+from ... import opts
 from ...custom_redis import redis
 from ...database.log_to_db import log_to_db
-from ...llm.generator import generator
 from ...sock import sock


|
||||||
|
|
||||||
|
|
||||||
def do_stream(ws, model_name):
|
def do_stream(ws, model_name):
|
||||||
|
event_id = None
|
||||||
try:
|
try:
|
||||||
def send_err_and_quit(quitting_err_msg):
|
def send_err_and_quit(quitting_err_msg):
|
||||||
ws.send(json.dumps({
|
ws.send(json.dumps({
|
||||||
|
@ -46,6 +48,7 @@ def do_stream(ws, model_name):
|
||||||
'event': 'stream_end',
|
'event': 'stream_end',
|
||||||
'message_num': 1
|
'message_num': 1
|
||||||
}))
|
}))
|
||||||
|
ws.close()
|
||||||
log_to_db(ip=handler.client_ip,
|
log_to_db(ip=handler.client_ip,
|
||||||
token=handler.token,
|
token=handler.token,
|
||||||
prompt=input_prompt,
|
prompt=input_prompt,
|
||||||
|
@@ -55,7 +58,7 @@ def do_stream(ws, model_name):
                       headers=r_headers,
                       backend_response_code=response_status_code,
                       request_url=r_url,
-                      backend_url=handler.cluster_backend_info,
+                      backend_url=handler.backend_url,
                       response_tokens=None,
                       is_error=True
                       )
@@ -74,6 +77,7 @@ def do_stream(ws, model_name):
         if not request_valid_json or not request_json_body.get('prompt'):
             return 'Invalid JSON', 400
         else:
+            # We have to do auth ourselves since the details are sent in the message.
             auth_failure = require_api_key(request_json_body)
             if auth_failure:
                 return auth_failure
@@ -89,14 +93,10 @@ def do_stream(ws, model_name):
             }))
             return

-        assert not handler.offline
-
         if handler.cluster_backend_info['mode'] != 'vllm':
             # TODO: implement other backends
             raise NotImplementedError

-        event_id = None
-        generated_text = ''
         input_prompt = request_json_body['prompt']
         response_status_code = 0
         start_time = time.time()
@@ -113,119 +113,55 @@ def do_stream(ws, model_name):
             send_err_and_quit(err_msg)
             return

-        llm_request = {
-            **handler.parameters,
-            'prompt': input_prompt,
-            'stream': True,
+        handler.parameters, _ = handler.get_parameters()
+        handler.prompt = input_prompt
+        handler.request_json_body = {
+            'prompt': handler.prompt,
+            **handler.parameters
         }

         event = None
         if not handler.is_client_ratelimited():
-            # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
         if not event:
-            log_to_db(
-                handler.client_ip,
-                handler.token,
-                handler.request_json_body.get('prompt'),
-                None,
-                None,
-                handler.parameters,
-                request.headers,
-                response_status_code,
-                request.url,
-                handler.backend_url,
-            )
-            return handler.handle_ratelimited()
-
-        # Wait for permission to begin.
-        event_id = event.event_id
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
-        for item in pubsub.listen():
-            if item['type'] == 'message':
-                msg = item['data'].decode('utf-8')
-                if msg == 'begin':
-                    break
-                elif msg == 'offline':
-                    return messages.BACKEND_OFFLINE, 404  # TODO: format this error
-            time.sleep(0.1)
-
-        # Double check the model is still online
-        if not handler.check_online():
-            return messages.BACKEND_OFFLINE, 404  # TODO: format this error
-
-        try:
-            response = generator(llm_request, handler.backend_url)
-            if not response:
-                error_msg = 'Failed to reach backend while streaming.'
-                print('Streaming failed:', error_msg)
-                msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
-                ws.send(json.dumps({
-                    'event': 'text_stream',
-                    'message_num': message_num,
-                    'text': msg
-                }))
-            else:
-                # Be extra careful when getting attributes from the response object
-                try:
-                    response_status_code = response.status_code
-                except:
-                    response_status_code = 0
-
-                partial_response = b''
-
-                for chunk in response.iter_content(chunk_size=1):
-                    partial_response += chunk
-                    if partial_response.endswith(b'\x00'):
-                        json_strs = partial_response.split(b'\x00')
-                        for json_str in json_strs:
-                            if json_str:
-                                try:
-                                    json_obj = json.loads(json_str.decode())
-                                    new = json_obj['text'][0].split(input_prompt + generated_text)[1]
-                                    generated_text = generated_text + new
-                                except IndexError:
-                                    # ????
-                                    continue
-                                try:
-                                    ws.send(json.dumps({
-                                        'event': 'text_stream',
-                                        'message_num': message_num,
-                                        'text': new
-                                    }))
-                                except:
-                                    # The has client closed the stream.
-                                    if response:
-                                        # Cancel the backend?
-                                        response.close()
-                                    # used to log here
-                                    return
-                                message_num += 1
-                        partial_response = b''  # Reset the partial response
-
-                    # If there is no more data, break the loop
-                    if not chunk:
-                        break
-            if response:
-                response.close()
-            # used to log here
-        except:
-            traceback.print_exc()
-            generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
-            ws.send(json.dumps({
-                'event': 'text_stream',
-                'message_num': message_num,
-                'text': generated_text
-            }))
-            # used to log here
+            r = handler.handle_ratelimited()
+            send_err_and_quit(r[0].data)
+            return
+        event_id = event.event_id
+
+        stream_name = event.wait()
+        stream_redis = Redis(db=8)
+        generated_text = ''
+
+        try:
+            last_id = '0-0'  # The ID of the last entry we read.
+            while True:
+                stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
+                if not stream_data:
+                    print("No message received in 30 seconds, closing stream.")
+                    return
+                else:
+                    for stream_index, item in stream_data[0][1]:
+                        last_id = stream_index
+                        data = pickle.loads(item[b'data'])
+                        if data['error']:
+                            print(data['error'])
+                            send_err_and_quit('Encountered exception while streaming.')
+                            return
+                        elif data['new']:
+                            ws.send(json.dumps({
+                                'event': 'text_stream',
+                                'message_num': message_num,
+                                'text': data['new']
+                            }))
+                            message_num += 1
+                            generated_text = generated_text + data['new']
+                        elif data['completed']:
+                            return
+        except:
+            send_err_and_quit('Encountered exception while streaming.')
+            traceback.print_exc()
         finally:
-            if event_id:
-                redis.publish(event_id, 'finished')
-            else:
-                print('event_id was None!')
-
             try:
                 ws.send(json.dumps({
                     'event': 'stream_end',
@@ -234,6 +170,7 @@ def do_stream(ws, model_name):
             except:
                 # The client closed the stream.
                 pass
+            stream_redis.delete(stream_name)
             end_time = time.time()
             elapsed_time = end_time - start_time
             log_to_db(ip=handler.client_ip,
@@ -248,6 +185,8 @@ def do_stream(ws, model_name):
                       backend_url=handler.backend_url
                       )
     finally:
+        if event_id:
+            redis.lpush(f'notifications:{event_id}', 'canceled')
         try:
             # Must close the connection or greenlets will complain.
             ws.close()
@@ -3,6 +3,6 @@ from flask_sock import Sock
 sock = Sock()


-def init_socketio(app):
+def init_wssocket(app):
     global sock
     sock.init_app(app)
@@ -7,7 +7,7 @@ from uuid import uuid4
 from redis import Redis

 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import RedisCustom
+from llm_server.custom_redis import RedisCustom, redis
 from llm_server.llm.generator import generator
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count

@@ -20,15 +20,25 @@ def get_stream_name(name: str):
     return f'{STREAM_NAME_PREFIX}:{name}'


-def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
+def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
     prompt = msg_to_backend['prompt']
     stream_name = get_stream_name(stream_name)
+    redis.delete(f'notifications:{event_id}')
     stream_redis.delete(get_stream_name(stream_name))  # be extra sure
     try:
         response = generator(msg_to_backend, backend_url)
         generated_text = ''
         partial_response = b''
         for chunk in response.iter_content(chunk_size=1):
+            # If there is no more data, break the loop
+            if not chunk:
+                break
+            message = redis.lpop(f'notifications:{event_id}')
+            if message and message.decode('utf-8') == 'canceled':
+                print('Client canceled generation')
+                response.close()
+                return
+
             partial_response += chunk
             if partial_response.endswith(b'\x00'):
                 json_strs = partial_response.split(b'\x00')
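With this change, cancellation travels over a plain Redis list instead of pub/sub: the route handlers LPUSH 'canceled' onto notifications:<event_id> when the client goes away, and the worker LPOPs that list between chunks and closes the backend response if it finds the marker. A small sketch of the two halves; the function names are hypothetical:

    from redis import Redis

    redis = Redis()


    def signal_cancel(event_id: str) -> None:
        # Route side: tell the worker that the client has disconnected.
        redis.lpush(f'notifications:{event_id}', 'canceled')


    def was_canceled(event_id: str) -> bool:
        # Worker side: non-blocking check between streamed chunks.
        message = redis.lpop(f'notifications:{event_id}')
        return bool(message and message.decode('utf-8') == 'canceled')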
@@ -74,14 +84,16 @@ def worker(backend_url):

         try:
             if do_stream:
+                # Return the name of the stream that the slave should connect to.
                 event = DataEvent(event_id)
                 event.set(get_stream_name(worker_id))

                 msg_to_backend = {
                     **parameters,
                     'prompt': request_json_body['prompt'],
                     'stream': True,
                 }
-                inference_do_stream(worker_id, msg_to_backend, backend_url)
+                inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
             else:
                 # Normal inference (not streaming).
                 success, response, error_msg = generator(request_json_body, backend_url)
@@ -29,4 +29,4 @@ def console_printer():
         # TODO: Active Workers and Processing should read the same. If not, that's an issue

         logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
-        time.sleep(10)
+        time.sleep(2)
@@ -11,6 +11,7 @@ except ImportError:
 HOST = 'localhost:5000'
 URI = f'ws://{HOST}/api/v1/stream'

+
 # For reverse-proxied streaming, the remote will likely host with ssl - wss://
 # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'

@@ -82,5 +83,6 @@ async def print_response_stream(prompt):


 if __name__ == '__main__':
-    prompt = "In order to make homemade bread, follow these steps:\n1)"
+    # prompt = "In order to make homemade bread, follow these steps:\n1)"
+    prompt = "Write a 300 word description of how an apple tree grows.\n\n"
     asyncio.run(print_response_stream(prompt))
@@ -28,7 +28,7 @@ from llm_server.routes.openai import openai_bp, openai_model_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.sock import init_socketio
+from llm_server.sock import init_wssocket

 # TODO: queue item timeout
 # TODO: return an `error: True`, error code, and error message rather than just a formatted message
@@ -68,10 +68,15 @@ except ModuleNotFoundError as e:
     sys.exit(1)

 app = Flask(__name__)
+
+# Fixes ConcurrentObjectUseError
+# https://github.com/miguelgrinberg/simple-websocket/issues/24
+app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
+
 app.register_blueprint(bp, url_prefix='/api/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
-init_socketio(app)
+init_wssocket(app)
 flask_cache.init_app(app)
 flask_cache.clear()
