fix streaming?
parent 67173f30dd
commit e9f6fdf65e
@@ -33,7 +33,6 @@ config_default_vars = {
'openai_moderation_enabled': True,
'netdata_root': None,
'show_backends': True,
'cluster_workers': 30,
'background_homepage_cacher': True,
'openai_moderation_timeout': 5,
'prioritize_by_size': False

@@ -45,12 +45,15 @@ def load_config(config_path):
opts.openai_silent_trim = config['openai_silent_trim']
opts.openai_moderation_enabled = config['openai_moderation_enabled']
opts.show_backends = config['show_backends']
opts.cluster_workers = config['cluster_workers']
opts.background_homepage_cacher = config['background_homepage_cacher']
opts.openai_moderation_timeout = config['openai_moderation_timeout']
opts.frontend_api_mode = config['frontend_api_mode']
opts.prioritize_by_size = config['prioritize_by_size']

# Scale the number of workers.
for item in config['cluster']:
opts.cluster_workers += item['concurrent_gens']

if opts.openai_expose_our_model and not opts.openai_api_key:
print('If you set openai_expose_our_model to false, you must set your OpenAI key in openai_api_key.')
sys.exit(1)

@@ -34,7 +34,7 @@ openai_silent_trim = False
openai_moderation_enabled = True
cluster = {}
show_backends = True
cluster_workers = 30
background_homepage_cacher = True
openai_moderation_timeout = 5
prioritize_by_size = False
prioritize_by_size = False
cluster_workers = 0

@@ -8,7 +8,7 @@ from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..openai_request_handler import OpenAIRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ..queue import priority_queue
from ... import opts
from ...database.log_to_db import log_to_db
from ...llm.generator import generator

@@ -57,6 +57,7 @@ def openai_chat_completions():
else:
handler.prompt = transform_messages_to_prompt(handler.request.json['messages'])

event_id = None
response_status_code = 0
start_time = time.time()

@@ -70,8 +71,10 @@ def openai_chat_completions():
'stream': True,
}

# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
event = None
if not handler.is_client_ratelimited():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
log_to_db(
handler.client_ip,

@@ -87,8 +90,15 @@ def openai_chat_completions():
)
return handler.handle_ratelimited()

# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
# Once the worker receives our streaming request, it will tell us we are ready
# to begin inference.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
break
time.sleep(0.1)

try:
r_headers = dict(request.headers)

@@ -97,61 +107,63 @@ def openai_chat_completions():
oai_string = generate_oai_string(30)

def generate():
try:
response = generator(msg_to_backend, handler.backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
response = generator(msg_to_backend, handler.backend_url)
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue

data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
data = {
"id": f"chatcmpl-{oai_string}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)

return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500
finally:
# After completing inference, we need to tell the worker we
# are finished.
if event_id: # may be None if ratelimited.
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
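
The handler-side pattern this file now uses (and which the other streaming endpoints below repeat) is: put a placeholder job on the queue, wait for a worker to dequeue it, wait on a Redis pub/sub channel for 'begin', stream the response, then publish 'finished' so the blocked worker can release its slot. A minimal sketch of the two pub/sub halves of that sequence, assuming redis-py; wait_for_begin and notify_finished are illustrative names, not functions from this commit:

import redis

r = redis.Redis()

def wait_for_begin(event_id: str) -> None:
    # Block until the worker that dequeued our placeholder job publishes
    # 'begin' on the channel named after the event id.
    pubsub = r.pubsub()
    pubsub.subscribe(event_id)
    for item in pubsub.listen():
        if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
            break

def notify_finished(event_id: str) -> None:
    # Release the worker that is blocked waiting for us to finish streaming.
    r.publish(event_id, 'finished')

In the endpoint above, the 'finished' publish sits in a finally: block so the worker is released even if streaming raises.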

@@ -8,9 +8,8 @@ from llm_server.custom_redis import redis
from . import openai_bp
from ..helpers.http import validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ..queue import priority_queue
from ... import opts
from ...database.database import do_db_log
from ...database.log_to_db import log_to_db
from ...llm import get_token_count
from ...llm.generator import generator

@@ -53,7 +52,6 @@ def openai_completions():
return handler.handle_ratelimited()
output = response.json['results'][0]['text']

# TODO: async/await
prompt_tokens = get_token_count(request_json_body['prompt'], handler.backend_url)
response_tokens = get_token_count(output, handler.backend_url)
running_model = redis.get('running_model', 'ERROR', dtype=str)

@@ -86,6 +84,7 @@ def openai_completions():
if not opts.enable_streaming:
return 'DISABLED', 401

event_id = None
response_status_code = 0
start_time = time.time()

@@ -100,8 +99,10 @@ def openai_completions():
'stream': True,
}

# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
event = None
if not handler.is_client_ratelimited():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
log_to_db(
handler.client_ip,

@@ -117,8 +118,14 @@ def openai_completions():
)
return handler.handle_ratelimited()

# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
# Wait for permission to begin.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
break
time.sleep(0.1)

try:
response = generator(msg_to_backend, handler.backend_url)

@@ -128,61 +135,61 @@ def openai_completions():
oai_string = generate_oai_string(30)

def generate():
try:
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue
generated_text = ''
partial_response = b''
for chunk in response.iter_content(chunk_size=1):
partial_response += chunk
if partial_response.endswith(b'\x00'):
json_strs = partial_response.split(b'\x00')
for json_str in json_strs:
if json_str:
try:
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
except IndexError:
# ????
continue

data = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time
data = {
"id": f"cmpl-{oai_string}",
"object": "text_completion",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {
"content": new
},
"finish_reason": None
}
]
}
yield f'data: {json.dumps(data)}\n\n'
yield 'data: [DONE]\n\n'
end_time = time.time()
elapsed_time = end_time - start_time

log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
log_to_db(
handler.client_ip,
handler.token,
handler.prompt,
generated_text,
elapsed_time,
handler.parameters,
r_headers,
response_status_code,
r_url,
handler.backend_url,
)

return Response(generate(), mimetype='text/event-stream')
except Exception:
traceback.print_exc()
return 'INTERNAL SERVER', 500
finally:
if event_id:
redis.publish(event_id, 'finished')
else:
print('event_id was None!')
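
Both generate() functions above share the same inner loop: accumulate the backend's NUL-delimited JSON documents, take the text that has not been sent yet, and re-emit it as an OpenAI-style SSE chunk. A condensed sketch of that loop; sse_chunks and its parameters are illustrative, response is assumed to be a streaming requests.Response as returned by generator(), and the chunk shape shown is the chat-completion one (the completions endpoint differs only in the id prefix and object field):

import json
import time

def sse_chunks(response, prompt: str, oai_id: str, model: str):
    # Each NUL-terminated JSON document carries the prompt plus everything
    # generated so far in ['text'][0]; the new delta is whatever follows
    # prompt + generated_text.
    generated_text = ''
    buffer = b''
    for chunk in response.iter_content(chunk_size=1):
        buffer += chunk
        if not buffer.endswith(b'\x00'):
            continue
        for json_str in buffer.split(b'\x00'):
            if not json_str:
                continue
            try:
                new = json.loads(json_str.decode())['text'][0].split(prompt + generated_text)[1]
            except IndexError:
                continue  # no new delta in this document
            generated_text += new
            yield 'data: ' + json.dumps({
                'id': f'chatcmpl-{oai_id}',
                'object': 'chat.completion.chunk',
                'created': int(time.time()),
                'model': model,
                'choices': [{'index': 0, 'delta': {'content': new}, 'finish_reason': None}],
            }) + '\n\n'
        buffer = b''  # the diff keeps accumulating instead; clearing here is a simplification
    yield 'data: [DONE]\n\n'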

@@ -22,8 +22,6 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue:
def __init__(self, name: str = 'priority_queue', db: int = 12):
self.redis = RedisCustom(name, db=db)
self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events')

def put(self, item, priority, selected_model):
event = DataEvent()

@@ -36,8 +34,6 @@ class RedisPriorityQueue:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request

print('--->', event.event_id)

self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
return event

@@ -54,17 +50,14 @@ class RedisPriorityQueue:

def print_all_items(self):
items = self.redis.zrange('queue', 0, -1)
print(items)
for item in items:
print(item.decode('utf-8'))

def increment_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, 1)
print(client_ip, new_count)

def decrement_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, -1)
print(client_ip, new_count)
if new_count <= 0:
self.redis.hdel(redis_key, client_ip)
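
The streaming endpoints in this commit use the queue purely for admission control: they put() a placeholder item whose request body is None, wait for a worker to take it, and then perform the generation themselves. A hedged usage sketch reusing the names from this diff; the import path and the handler object are assumptions based on the relative imports above:

from llm_server.routes.queue import priority_queue  # path inferred from 'from ..queue import priority_queue'

def reserve_worker_slot(handler):
    # Illustrative helper, not from the commit.
    event = priority_queue.put(
        (None, handler.client_ip, handler.token, None, handler.backend_url),  # None request body marks a dummy job
        handler.token_priority,
        handler.selected_model,
    )
    if not event:
        return None  # rejected: this IP already has too many requests queued
    _, _, _ = event.wait()   # a worker has dequeued the placeholder
    return event.event_id    # channel name for the begin/finished handshake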

@@ -7,8 +7,9 @@ from flask import request
from . import bp
from ..helpers.http import require_api_key, validate_json
from ..ooba_request_handler import OobaRequestHandler
from ..queue import decr_active_workers, decrement_ip_count, priority_queue
from ..queue import priority_queue
from ... import opts
from ...custom_redis import redis
from ...database.log_to_db import log_to_db
from ...llm.generator import generator
from ...sock import sock

@@ -94,6 +95,7 @@ def do_stream(ws, model_name):
# TODO: implement other backends
raise NotImplementedError

event_id = None
generated_text = ''
input_prompt = request_json_body['prompt']
response_status_code = 0

@@ -117,16 +119,33 @@ def do_stream(ws, model_name):
'stream': True,
}

# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
event = None
if not handler.is_client_ratelimited():
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put((None, handler.client_ip, handler.token, None, handler.backend_url), handler.token_priority, handler.selected_model)
if not event:
r, _ = handler.handle_ratelimited()
err_msg = r.json['results'][0]['text']
send_err_and_quit(err_msg)
return
log_to_db(
handler.client_ip,
handler.token,
handler.request_json_body.get('prompt'),
None,
None,
handler.parameters,
request.headers,
response_status_code,
request.url,
handler.backend_url,
)
return handler.handle_ratelimited()

# Wait for a worker to get our request and discard it.
_, _, _ = event.wait()
# Wait for permission to begin.
event_id = event.event_id
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
break
time.sleep(0.1)

try:
response = generator(llm_request, handler.backend_url)

@@ -195,9 +214,11 @@ def do_stream(ws, model_name):
}))
# used to log here
finally:
# The worker incremented it, we'll decrement it.
decrement_ip_count(handler.client_ip, 'processing_ips')
decr_active_workers(handler.selected_model, handler.backend_url)
if event_id:
redis.publish(event_id, 'finished')
else:
print('event_id was None!')

try:
ws.send(json.dumps({
'event': 'stream_end',

@@ -19,27 +19,30 @@ def worker():
if not selected_model:
selected_model = backend_info['model']

# This wait time will be "invisible", meaning the worker may as
# well be still waiting to get an item from the queue.
need_to_wait(backend_url)

increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)

print('<---', event_id)

if not request_json_body:
# This was a dummy request from the websocket handlers.
# We're going to let the websocket handler decrement
# processing_ips and active_gen_workers.
event = DataEvent(event_id)
event.set((True, None, None))
continue

try:
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)
event.set((success, response, error_msg))
if not request_json_body:
# This was a dummy request from the streaming handlers.
# The worker will let the handler do the streaming instead
# of the worker. The worker will block until the handler
# is finished. Since a lot of ratelimiting and stats are
# based off the number of active workers, we must keep
# the generation based off the workers.
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
redis.publish(event_id, 'begin')
for item in pubsub.listen():
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
# Once the handler is complete, move on.
break
time.sleep(0.1)
else:
# Normal inference (not streaming).
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)
event.set((success, response, error_msg))
finally:
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)

@@ -53,16 +56,3 @@ def start_workers(num_workers: int):
t.start()
i += 1
print(f'Started {i} inference workers.')


def need_to_wait(backend_url: str):
# We need to check the number of active workers since the streaming endpoint may be doing something.
active_workers = redis.get(f'active_gen_workers:{backend_url}', 0, dtype=int)
concurrent_gens = cluster_config.get_backend(backend_url).get('concurrent_gens', 1)
s = time.time()
print(active_workers)
while active_workers >= concurrent_gens:
time.sleep(0.01)
e = time.time()
if e - s > 0.1:
print(f'Worker was delayed {e - s} seconds.')
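
Worker side of the same handshake, as a sketch: when the dequeued item has no request body, the worker publishes 'begin' to hand control to the streaming handler and then blocks on the same channel until the handler publishes 'finished', so the per-IP and active-worker counters keep covering the whole generation. Assumes redis-py; hold_slot_for_streaming is an illustrative name, not from the commit:

import redis

r = redis.Redis()

def hold_slot_for_streaming(event_id: str) -> None:
    # Subscribe before publishing so the handler's 'finished' message
    # cannot arrive on a channel we are not yet listening to.
    pubsub = r.pubsub()
    pubsub.subscribe(event_id)
    r.publish(event_id, 'begin')
    for item in pubsub.listen():
        if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
            break

One caveat worth noting: Redis pub/sub does not buffer, so if 'begin' is published before the handler has subscribed to the channel, the handler will wait for a message that never arrives.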