get streaming working again

Cyberes 2023-10-16 16:22:52 -06:00
parent 151b3e4769
commit 2c7773cc4f
13 changed files with 296 additions and 373 deletions

View File

@@ -223,5 +223,14 @@ class RedisCustom(Redis):
         self.flush()
         return True

+    def lrange(self, name: str, start: int, end: int):
+        return self.redis.lrange(self._key(name), start, end)
+
+    def delete(self, *names: KeyT):
+        return self.redis.delete(*[self._key(i) for i in names])
+
+    def lpop(self, name: str, count: Optional[int] = None):
+        return self.redis.lpop(self._key(name), count)
+

 redis = RedisCustom('local_llm')
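
For context, the three new methods follow RedisCustom's existing pattern: forward to the underlying client with a namespaced key. A minimal sketch of that pattern, assuming _key() (not shown in this hunk) simply prepends the prefix passed to the constructor:

    from typing import Optional
    from redis import Redis

    class PrefixedRedis:
        def __init__(self, prefix: str):
            self.redis = Redis()
            self.prefix = prefix

        def _key(self, name: str) -> str:
            # e.g. _key('queue') -> 'local_llm:queue'
            return f'{self.prefix}:{name}'

        def lpop(self, name: str, count: Optional[int] = None):
            # Forward to the real client under the namespaced key.
            return self.redis.lpop(self._key(name), count)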

View File

@@ -5,6 +5,9 @@ from redis import Redis


 def log_to_db(ip: str, token: str, prompt: str, response: Union[str, None], gen_time: Union[int, float, None], parameters: dict, headers: dict, backend_response_code: int, request_url: str, backend_url: str, response_tokens: int = None, is_error: bool = False):
+    assert isinstance(prompt, str)
+    assert isinstance(backend_url, str)
+
     r = Redis(host='localhost', port=6379, db=3)
     data = {
         'function': 'log_prompt',

View File

@@ -1,52 +0,0 @@
-import json
-from datetime import datetime, timedelta
-
-import requests
-
-from llm_server import opts
-
-
-def get_power_states():
-    gpu_num = 0
-    output = {}
-    while True:
-        url = f"{opts.netdata_root}/api/v1/data?chart=nvidia_smi.gpu{gpu_num}_power_state"
-        try:
-            response = requests.get(url, timeout=10)
-            if response.status_code != 200:
-                break
-            data = json.loads(response.text)
-            power_state_data = data['data'][0]
-            power_state = None
-            for i in range(1, len(power_state_data)):
-                if power_state_data[i] == 1:
-                    power_state = data['labels'][i]
-                    break
-            output[f'gpu{gpu_num}'] = int(power_state.lower().strip('p'))
-        except Exception as e:
-            print('Failed to fetch Netdata metrics:', e)
-            return output
-        gpu_num += 1
-    return output
-
-
-def get_gpu_wh(gpu_id: int):
-    chart_name = f"nvidia_smi.gpu{gpu_id}_power"
-    now = datetime.now()
-    one_hour_ago = now - timedelta(hours=1)
-    num_seconds = int((now - one_hour_ago).total_seconds())
-    params = {
-        "chart": chart_name,
-        "after": int(one_hour_ago.timestamp()),
-        "before": int(now.timestamp()),
-        "points": num_seconds,
-        "group": "second",
-        "format": "json",
-        "options": "absolute|jsonwrap"
-    }
-    response = requests.get(f'{opts.netdata_root}/api/v1/data', params=params, timeout=10)
-    data = json.loads(response.text)
-    total_power_usage_watts = sum(point[1] for point in data['result']['data'])
-    # total_power_usage_watt_hours = round(total_power_usage_watts / 3600, 1)
-    total_power_usage_kwh = round(total_power_usage_watts / 1000 / 3600, 3)
-    return total_power_usage_kwh
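
(For reference on the removed get_gpu_wh: Netdata returns one sample per second, so summing the watt readings gives watt-seconds, i.e. joules; dividing by 3600 converts to watt-hours and by a further 1000 to kilowatt-hours. For example, a steady 250 W draw over an hour sums to 900,000 watt-seconds, and 900000 / 1000 / 3600 = 0.25 kWh.)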

View File

@@ -43,24 +43,23 @@ def openai_chat_completions(model_name=None):
         if not opts.enable_streaming:
             return 'Streaming disabled', 403

+        handler.parameters, _ = handler.get_parameters()
+        handler.request_json_body = {
+            'messages': handler.request_json_body['messages'],
+            'model': handler.request_json_body['model'],
+            **handler.parameters
+        }
+
         invalid_oai_err_msg = validate_oai(handler.request_json_body)
         if invalid_oai_err_msg:
             return invalid_oai_err_msg

         handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=True, mode=handler.cluster_backend_info['mode'])

-        handler.parameters, e = handler.get_parameters()
-        handler.request_json_body = {
-            'messages': handler.request_json_body['messages'],
-            'model': handler.request_json_body['model'],
-            **handler.parameters
-        }
-
         if opts.openai_silent_trim:
             handler.prompt = transform_messages_to_prompt(trim_messages_to_fit(handler.request.json['messages'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url))
         else:
             handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])
         if not handler.prompt:
             # Prevent issues on the backend.
             return 'Invalid prompt', 400

@@ -73,7 +72,7 @@ def openai_chat_completions(model_name=None):
         request_valid, invalid_response = handler.validate_request()
         if not request_valid:
             return invalid_response
-        else:
-            event = None
+
+        event = None
         if not handler.is_client_ratelimited():
             event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)

@@ -103,14 +102,16 @@ def openai_chat_completions(model_name=None):
         stream_redis = Redis(db=8)
         generated_text = ''
         try:
+            last_id = '0-0'
             while True:
-                stream_data = stream_redis.xread({stream_name: '0-0'}, block=30000)
+                stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
                 if not stream_data:
                     print("No message received in 30 seconds, closing stream.")
                     yield 'data: [DONE]\n\n'
                 else:
-                    for r_timestamp, item in stream_data[0][1]:
-                        timestamp = int(r_timestamp.decode('utf-8').split('-')[0])
+                    for stream_index, item in stream_data[0][1]:
+                        last_id = stream_index
+                        timestamp = int(stream_index.decode('utf-8').split('-')[0])
                         data = pickle.loads(item[b'data'])
                         if data['error']:
                             yield 'data: [DONE]\n\n'

@@ -154,6 +155,8 @@ def openai_chat_completions(model_name=None):
             traceback.print_exc()
             yield 'data: [DONE]\n\n'
         finally:
+            if event:
+                redis.lpush(f'notifications:{event.event_id}', 'canceled')
             stream_redis.delete(stream_name)

     return Response(generate(), mimetype='text/event-stream')
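
The substantive fix in this generator is the last_id cursor: the old code passed '0-0' to every xread call, so each blocking read returned the stream from the beginning and replayed chunks it had already sent. A standalone sketch of the corrected pattern, with a hypothetical stream name and payload (in the real code the name comes from event.wait() and the payload is a pickled dict):

    import pickle

    from redis import Redis

    stream_redis = Redis(db=8)
    stream_name = 'streaming:example'  # hypothetical

    last_id = '0-0'
    while True:
        # Block up to 30 s waiting for entries *after* last_id.
        stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
        if not stream_data:
            break  # timed out with no new entries
        for entry_id, fields in stream_data[0][1]:
            last_id = entry_id  # advance the cursor so nothing is re-read
            data = pickle.loads(fields[b'data'])
            print(data)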

View File

@ -1,8 +1,10 @@
import pickle
import time import time
import traceback import traceback
import simplejson as json import simplejson as json
from flask import Response, jsonify, request from flask import Response, jsonify, request
from redis import Redis
from llm_server.custom_redis import redis from llm_server.custom_redis import redis
from . import openai_bp, openai_model_bp from . import openai_bp, openai_model_bp
@ -12,7 +14,6 @@ from ..queue import priority_queue
from ... import opts from ... import opts
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm import get_token_count from ...llm import get_token_count
from ...llm.generator import generator
from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai from ...llm.openai.oai_to_vllm import oai_to_vllm, return_invalid_model_err, validate_oai
from ...llm.openai.transform import generate_oai_string, trim_string_to_fit from ...llm.openai.transform import generate_oai_string, trim_string_to_fit
@@ -42,12 +43,14 @@ def openai_completions(model_name=None):
     handler.request_json_body = oai_to_vllm(handler.request_json_body, stop_hashes=False, mode=handler.cluster_backend_info['mode'])

     if opts.openai_silent_trim:
-        handler.request_json_body['prompt'] = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
+        handler.prompt = trim_string_to_fit(request_json_body['prompt'], handler.cluster_backend_info['model_config']['max_position_embeddings'], handler.backend_url)
     else:
         # The handle_request() call below will load the prompt so we don't have
         # to do anything else here.
         pass
+    handler.request_json_body['prompt'] = handler.prompt

     if not request_json_body.get('stream'):
         invalid_oai_err_msg = validate_oai(request_json_body)
         if invalid_oai_err_msg:
@@ -89,24 +92,36 @@ def openai_completions(model_name=None):
         if not opts.enable_streaming:
             return 'Streaming disabled', 403

-        event_id = None
+        request_valid, invalid_response = handler.validate_request()
+        if not request_valid:
+            return invalid_response
+
+        handler.parameters, _ = handler.get_parameters()
+        handler.request_json_body = {
+            'prompt': handler.request_json_body['prompt'],
+            'model': handler.request_json_body['model'],
+            **handler.parameters
+        }
+
+        invalid_oai_err_msg = validate_oai(handler.request_json_body)
+        if invalid_oai_err_msg:
+            return invalid_oai_err_msg
+
+        if opts.openai_silent_trim:
+            handler.request_json_body['prompt'] = handler.request_json_body['prompt'][:handler.cluster_backend_info['model_config']['max_position_embeddings']]
+
+        if not handler.prompt:
+            # Prevent issues on the backend.
+            return 'Invalid prompt', 400

         start_time = time.time()

         request_valid, invalid_response = handler.validate_request()
         if not request_valid:
             return invalid_response
-        else:
-            handler.prompt = handler.request_json_body['prompt']
-            msg_to_backend = {
-                **handler.parameters,
-                'prompt': handler.prompt,
-                'stream': True,
-            }

         event = None
         if not handler.is_client_ratelimited():
-            # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
         if not event:
             log_to_db(
                 handler.client_ip,
@ -122,68 +137,53 @@ def openai_completions(model_name=None):
) )
return handler.handle_ratelimited() return handler.handle_ratelimited()
# Wait for permission to begin.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
return return_invalid_model_err(handler.request_json_body['model'])
time.sleep(0.1)
# Double check the model is still online
if not handler.check_online():
return return_invalid_model_err(handler.request_json_body['model'])
try: try:
response = generator(msg_to_backend, handler.backend_url)
r_headers = dict(request.headers) r_headers = dict(request.headers)
r_url = request.url r_url = request.url
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model') model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30) oai_string = generate_oai_string(30)
def generate(): def generate():
try: stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = '' generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try: try:
json_obj = json.loads(json_str.decode()) last_id = '0-0'
new = json_obj['text'][0].split(handler.prompt + generated_text)[1] while True:
generated_text = generated_text + new stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
except IndexError: if not stream_data:
# ???? print("No message received in 30 seconds, closing stream.")
continue yield 'data: [DONE]\n\n'
else:
data = { for stream_index, item in stream_data[0][1]:
last_id = stream_index
timestamp = int(stream_index.decode('utf-8').split('-')[0])
data = pickle.loads(item[b'data'])
if data['error']:
yield 'data: [DONE]\n\n'
return
elif data['new']:
response = {
"id": f"cmpl-{oai_string}", "id": f"cmpl-{oai_string}",
"object": "text_completion", "object": "text_completion",
"created": int(time.time()), "created": timestamp,
"model": model, "model": model,
"choices": [ "choices": [
{ {
"index": 0, "index": 0,
"delta": { "delta": {
"content": new "content": data['new']
}, },
"finish_reason": None "finish_reason": None
} }
] ]
} }
yield f'data: {json.dumps(data)}\n\n' generated_text = generated_text + data['new']
yield f'data: {json.dumps(response)}\n\n'
elif data['completed']:
yield 'data: [DONE]\n\n' yield 'data: [DONE]\n\n'
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
log_to_db( log_to_db(
handler.client_ip, handler.client_ip,
handler.token, handler.token,
@ -196,11 +196,14 @@ def openai_completions(model_name=None):
r_url, r_url,
handler.backend_url, handler.backend_url,
) )
return
except (Exception, GeneratorExit):
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally: finally:
if event_id: if event:
redis.publish(event_id, 'finished') redis.lpush(f'notifications:{event.event_id}', 'canceled')
else: stream_redis.delete(stream_name)
print('event_id was None!')
return Response(generate(), mimetype='text/event-stream') return Response(generate(), mimetype='text/event-stream')
except Exception: except Exception:
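
A sketch of consuming the SSE stream this endpoint emits. The URL path and request body are assumptions for illustration (the blueprint is mounted under /api/openai/v1/ in server.py); the 'data: ...' / 'data: [DONE]' framing and the delta/content chunk shape match the generator above:

    import json

    import requests

    with requests.post('http://localhost:5000/api/openai/v1/completions',
                       json={'model': 'some-model', 'prompt': 'Hello,', 'stream': True},
                       stream=True) as resp:
        for line in resp.iter_lines():
            if not line:
                continue  # skip SSE keep-alive blank lines
            payload = line.decode().removeprefix('data: ')
            if payload == '[DONE]':
                break
            chunk = json.loads(payload)
            print(chunk['choices'][0]['delta']['content'], end='', flush=True)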

View File

@@ -150,10 +150,6 @@ class OpenAIRequestHandler(RequestHandler):
                 "total_tokens": prompt_tokens + response_tokens
             }
         }), 200)

-        stats = redis.get('proxy_stats', dtype=dict)
-        if stats:
-            response.headers['x-ratelimit-reset-requests'] = stats['queue']['estimated_wait_sec']
         return response

     def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:

View File

@@ -37,6 +37,9 @@ class RequestHandler:
         self.parameters = None
         self.used = False

+        # This is None by default since most handlers need to transform the prompt in a specific way.
+        self.prompt = None
+
         self.selected_model = selected_model
         self.backend_url = get_a_cluster_backend(selected_model)
         self.cluster_backend_info = cluster_config.get_backend(self.backend_url)

View File

@ -1,17 +1,18 @@
import json import json
import pickle
import time import time
import traceback import traceback
from flask import request from flask import request
from redis import Redis
from . import bp from . import bp
from ..helpers.http import require_api_key, validate_json from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler from ..ooba_request_handler import OobaRequestHandler
from ..queue import priority_queue from ..queue import priority_queue
from ... import messages, opts from ... import opts
from ...custom_redis import redis from ...custom_redis import redis
from ...database.log_to_db import log_to_db from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...sock import sock from ...sock import sock
@@ -35,6 +36,7 @@ def stream_with_model(ws, model_name=None):


 def do_stream(ws, model_name):
+    event_id = None
     try:
         def send_err_and_quit(quitting_err_msg):
             ws.send(json.dumps({
@@ -46,6 +48,7 @@ def do_stream(ws, model_name):
                 'event': 'stream_end',
                 'message_num': 1
             }))
+            ws.close()
             log_to_db(ip=handler.client_ip,
                       token=handler.token,
                       prompt=input_prompt,
@@ -55,7 +58,7 @@ def do_stream(ws, model_name):
                       headers=r_headers,
                       backend_response_code=response_status_code,
                       request_url=r_url,
-                      backend_url=handler.cluster_backend_info,
+                      backend_url=handler.backend_url,
                       response_tokens=None,
                       is_error=True
                       )
@@ -74,6 +77,7 @@ def do_stream(ws, model_name):
         if not request_valid_json or not request_json_body.get('prompt'):
             return 'Invalid JSON', 400
         else:
+            # We have to do auth ourselves since the details are sent in the message.
             auth_failure = require_api_key(request_json_body)
             if auth_failure:
                 return auth_failure
@@ -89,14 +93,10 @@ def do_stream(ws, model_name):
             }))
             return

-        assert not handler.offline
-
         if handler.cluster_backend_info['mode'] != 'vllm':
             # TODO: implement other backends
             raise NotImplementedError

-        event_id = None
-        generated_text = ''
         input_prompt = request_json_body['prompt']
         response_status_code = 0
         start_time = time.time()
@@ -113,119 +113,55 @@ def do_stream(ws, model_name):
             send_err_and_quit(err_msg)
             return

-        llm_request = {
-            **handler.parameters,
-            'prompt': input_prompt,
-            'stream': True,
+        handler.parameters, _ = handler.get_parameters()
+        handler.prompt = input_prompt
+        handler.request_json_body = {
+            'prompt': handler.prompt,
+            **handler.parameters
         }

         event = None
         if not handler.is_client_ratelimited():
-            # Add a dummy event to the queue and wait for it to reach a worker
-            event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
+            event = priority_queue.put(handler.backend_url, (handler.request_json_body, handler.client_ip, handler.token, handler.parameters), handler.token_priority, handler.selected_model, do_stream=True)
         if not event:
-            log_to_db(
-                handler.client_ip,
-                handler.token,
-                handler.request_json_body.get('prompt'),
-                None,
-                None,
-                handler.parameters,
-                request.headers,
-                response_status_code,
-                request.url,
-                handler.backend_url,
-            )
-            return handler.handle_ratelimited()
-
-        # Wait for permission to begin.
-        event_id = event.event_id
-        pubsub = redis.pubsub()
-        pubsub.subscribe(event_id)
-        for item in pubsub.listen():
-            if item['type'] == 'message':
-                msg = item['data'].decode('utf-8')
-                if msg == 'begin':
-                    break
-                elif msg == 'offline':
-                    return messages.BACKEND_OFFLINE, 404  # TODO: format this error
-            time.sleep(0.1)
-
-        # Double check the model is still online
-        if not handler.check_online():
-            return messages.BACKEND_OFFLINE, 404  # TODO: format this error
-
-        try:
-            response = generator(llm_request, handler.backend_url)
-            if not response:
-                error_msg = 'Failed to reach backend while streaming.'
-                print('Streaming failed:', error_msg)
-                msg = handler.handle_error(error_msg)[0].json['results'][0]['text']
-                ws.send(json.dumps({
-                    'event': 'text_stream',
-                    'message_num': message_num,
-                    'text': msg
-                }))
-            else:
-                # Be extra careful when getting attributes from the response object
-                try:
-                    response_status_code = response.status_code
-                except:
-                    response_status_code = 0
-
-                partial_response = b''
-                for chunk in response.iter_content(chunk_size=1):
-                    partial_response += chunk
-                    if partial_response.endswith(b'\x00'):
-                        json_strs = partial_response.split(b'\x00')
-                        for json_str in json_strs:
-                            if json_str:
-                                try:
-                                    json_obj = json.loads(json_str.decode())
-                                    new = json_obj['text'][0].split(input_prompt + generated_text)[1]
-                                    generated_text = generated_text + new
-                                except IndexError:
-                                    # ????
-                                    continue
-                                try:
-                                    ws.send(json.dumps({
-                                        'event': 'text_stream',
-                                        'message_num': message_num,
-                                        'text': new
-                                    }))
-                                except:
-                                    # The client has closed the stream.
-                                    if response:
-                                        # Cancel the backend?
-                                        response.close()
-                                    # used to log here
-                                    return
-                                message_num += 1
-                        partial_response = b''  # Reset the partial response
-                    # If there is no more data, break the loop
-                    if not chunk:
-                        break
-                if response:
-                    response.close()
-                # used to log here
-        except:
-            traceback.print_exc()
-            generated_text = generated_text + '\n\n' + handler.handle_error('Encountered error while streaming.', 'exception')[0].json['results'][0]['text']
-            ws.send(json.dumps({
-                'event': 'text_stream',
-                'message_num': message_num,
-                'text': generated_text
-            }))
-            # used to log here
+            r = handler.handle_ratelimited()
+            send_err_and_quit(r[0].data)
+            return
+
+        event_id = event.event_id
+        stream_name = event.wait()
+        stream_redis = Redis(db=8)
+        generated_text = ''
+        try:
+            last_id = '0-0'  # The ID of the last entry we read.
+            while True:
+                stream_data = stream_redis.xread({stream_name: last_id}, block=30000)
+                if not stream_data:
+                    print("No message received in 30 seconds, closing stream.")
+                    return
+                else:
+                    for stream_index, item in stream_data[0][1]:
+                        last_id = stream_index
+                        data = pickle.loads(item[b'data'])
+                        if data['error']:
+                            print(data['error'])
+                            send_err_and_quit('Encountered exception while streaming.')
+                            return
+                        elif data['new']:
+                            ws.send(json.dumps({
+                                'event': 'text_stream',
+                                'message_num': message_num,
+                                'text': data['new']
+                            }))
+                            message_num += 1
+                            generated_text = generated_text + data['new']
+                        elif data['completed']:
+                            return
+        except:
+            send_err_and_quit('Encountered exception while streaming.')
+            traceback.print_exc()
         finally:
-            if event_id:
-                redis.publish(event_id, 'finished')
-            else:
-                print('event_id was None!')
             try:
                 ws.send(json.dumps({
                     'event': 'stream_end',
@@ -234,6 +170,7 @@ def do_stream(ws, model_name):
             except:
                 # The client closed the stream.
                 pass
+            stream_redis.delete(stream_name)
             end_time = time.time()
             elapsed_time = end_time - start_time
             log_to_db(ip=handler.client_ip,
@@ -248,6 +185,8 @@ def do_stream(ws, model_name):
                       backend_url=handler.backend_url
                       )
     finally:
+        if event_id:
+            redis.lpush(f'notifications:{event_id}', 'canceled')
         try:
             # Must close the connection or greenlets will complain.
             ws.close()

View File

@@ -3,6 +3,6 @@ from flask_sock import Sock

 sock = Sock()


-def init_socketio(app):
+def init_wssocket(app):
     global sock
     sock.init_app(app)

View File

@@ -7,7 +7,7 @@ from uuid import uuid4
 from redis import Redis

 from llm_server.cluster.cluster_config import cluster_config
-from llm_server.custom_redis import RedisCustom
+from llm_server.custom_redis import RedisCustom, redis
 from llm_server.llm.generator import generator
 from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
@@ -20,15 +20,25 @@ def get_stream_name(name: str):
     return f'{STREAM_NAME_PREFIX}:{name}'


-def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str):
+def inference_do_stream(stream_name: str, msg_to_backend: dict, backend_url: str, event_id: str):
     prompt = msg_to_backend['prompt']
     stream_name = get_stream_name(stream_name)
+    redis.delete(f'notifications:{event_id}')
     stream_redis.delete(get_stream_name(stream_name))  # be extra sure
     try:
         response = generator(msg_to_backend, backend_url)
         generated_text = ''
         partial_response = b''
         for chunk in response.iter_content(chunk_size=1):
+            # If there is no more data, break the loop
+            if not chunk:
+                break
+
+            message = redis.lpop(f'notifications:{event_id}')
+            if message and message.decode('utf-8') == 'canceled':
+                print('Client canceled generation')
+                response.close()
+                return
+
             partial_response += chunk
             if partial_response.endswith(b'\x00'):
                 json_strs = partial_response.split(b'\x00')
@@ -74,14 +84,16 @@ def worker(backend_url):
         try:
             if do_stream:
+                # Return the name of the stream that the slave should connect to.
                 event = DataEvent(event_id)
                 event.set(get_stream_name(worker_id))

                 msg_to_backend = {
                     **parameters,
                     'prompt': request_json_body['prompt'],
                     'stream': True,
                 }
-                inference_do_stream(worker_id, msg_to_backend, backend_url)
+                inference_do_stream(worker_id, msg_to_backend, backend_url, event_id)
             else:
                 # Normal inference (not streaming).
                 success, response, error_msg = generator(request_json_body, backend_url)
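
Taken together with the route changes above, this commit replaces the old pub/sub "begin"/"finished" handshake with a simple Redis-list notification channel: when a client disconnects, the route handler pushes 'canceled' onto notifications:<event_id>, and the worker pops the same list between chunks so it can abandon the backend request early. A minimal sketch of the two halves (the key name and payload come from the diff; the event ID and setup are illustrative):

    from redis import Redis

    r = Redis()
    event_id = 'example-event-id'  # hypothetical

    # Route-handler side, run when the client goes away:
    r.lpush(f'notifications:{event_id}', 'canceled')

    # Worker side, checked once per streamed chunk:
    message = r.lpop(f'notifications:{event_id}')
    if message and message.decode('utf-8') == 'canceled':
        print('Client canceled generation')  # stop reading from the backend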

View File

@@ -29,4 +29,4 @@ def console_printer():

         # TODO: Active Workers and Processing should read the same. If not, that's an issue
         logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
-        time.sleep(10)
+        time.sleep(2)

View File

@@ -11,6 +11,7 @@ except ImportError:

 HOST = 'localhost:5000'
 URI = f'ws://{HOST}/api/v1/stream'
+
 # For reverse-proxied streaming, the remote will likely host with ssl - wss://
 # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'

@@ -82,5 +83,6 @@ async def print_response_stream(prompt):

 if __name__ == '__main__':
-    prompt = "In order to make homemade bread, follow these steps:\n1)"
+    # prompt = "In order to make homemade bread, follow these steps:\n1)"
+    prompt = "Write a 300 word description of how an apple tree grows.\n\n"
     asyncio.run(print_response_stream(prompt))

View File

@@ -28,7 +28,7 @@ from llm_server.routes.openai import openai_bp, openai_model_bp
 from llm_server.routes.server_error import handle_server_error
 from llm_server.routes.v1 import bp
 from llm_server.routes.v1.generate_stats import generate_stats
-from llm_server.sock import init_socketio
+from llm_server.sock import init_wssocket

 # TODO: queue item timeout
 # TODO: return an `error: True`, error code, and error message rather than just a formatted message

@@ -68,10 +68,15 @@ except ModuleNotFoundError as e:
     sys.exit(1)

 app = Flask(__name__)

+# Fixes ConcurrentObjectUseError
+# https://github.com/miguelgrinberg/simple-websocket/issues/24
+app.config['SOCK_SERVER_OPTIONS'] = {'ping_interval': 25}
+
 app.register_blueprint(bp, url_prefix='/api/')
 app.register_blueprint(openai_bp, url_prefix='/api/openai/v1/')
 app.register_blueprint(openai_model_bp, url_prefix='/api/openai/')
-init_socketio(app)
+init_wssocket(app)
 flask_cache.init_app(app)
 flask_cache.clear()