fix issues with queue and streaming

This commit is contained in:
Cyberes 2023-10-15 20:45:01 -06:00
parent 3ec9b2347f
commit 31ab4188f1
9 changed files with 119 additions and 98 deletions

View File

@@ -200,10 +200,10 @@ class RedisCustom(Redis):
return json.loads(r.decode("utf-8"))
def setp(self, name, value):
self.redis.set(name, pickle.dumps(value))
self.redis.set(self._key(name), pickle.dumps(value))
def getp(self, name: str):
r = self.redis.get(name)
r = self.redis.get(self._key(name))
if r:
return pickle.loads(r)
return r
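Illustration only, not part of the commit: both setp() and getp() were using the raw name instead of the namespaced self._key(name), so pickled values from different RedisCustom instances shared one keyspace and could collide. A minimal standalone sketch of the intended behaviour, assuming _key() simply prefixes the instance name:
import pickle
from redis import Redis
class NamespacedRedis:
    # Toy stand-in for RedisCustom: every key is prefixed with the instance name.
    def __init__(self, name: str, db: int = 0):
        self.name = name
        self.redis = Redis(host='localhost', port=6379, db=db)
    def _key(self, key: str) -> str:
        return f'{self.name}:{key}'
    def setp(self, name, value):
        # Pickle under the namespaced key, not the raw one.
        self.redis.set(self._key(name), pickle.dumps(value))
    def getp(self, name: str):
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r
# NamespacedRedis('worker_status').setp('abc', ('backend', 'ip')) and
# NamespacedRedis('other').getp('abc') now touch different Redis keys.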

View File

@@ -1,79 +1,6 @@
from flask import jsonify
from llm_server.custom_redis import redis
from ..llm_backend import LLMBackend
from ...database.database import do_db_log
from ...helpers import safe_list_get
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
class OobaboogaBackend(LLMBackend):
default_params = {}
def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
raise NotImplementedError('need to implement default_params')
backend_err = False
response_valid_json, response_json_body = validate_json(response)
if response:
try:
# Be extra careful when getting attributes from the response object
response_status_code = response.status_code
except:
response_status_code = 0
else:
response_status_code = None
# ===============================================
# We encountered an error
if not success or not response or error_msg:
if not error_msg or error_msg == '':
error_msg = 'Unknown error.'
else:
error_msg = error_msg.strip('.') + '.'
backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': error_msg,
'results': [{'text': backend_response}]
}), 400
# ===============================================
if response_valid_json:
backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
if not backend_response:
# Ooba doesn't return any error messages so we will just tell the client an error occurred
backend_err = True
backend_response = format_sillytavern_err(
f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
error_type='error',
backend_url=self.backend_url)
response_json_body['results'][0]['text'] = backend_response
if not backend_err:
redis.incr('proompts')
log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
return jsonify({
**response_json_body
}), 200
else:
backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
return jsonify({
'code': 500,
'msg': 'the backend did not return valid JSON',
'results': [{'text': backend_response}]
}), 400
def validate_params(self, params_dict: dict):
# No validation required
return True, None
def get_parameters(self, parameters):
del parameters['prompt']
return parameters
def __int__(self):
return

View File

@@ -33,7 +33,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
j = r.json()
return j['length']
except Exception as e:
print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
return len(tokenizer.encode(chunk)) + 10
# Use a ThreadPoolExecutor to send all chunks to the server at once
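A standalone sketch (not from the repo) of the chunked-tokenization pattern this hunk belongs to. The /tokenize endpoint path, its payload shape, the chunk size, and the tiktoken fallback are assumptions for illustration, not necessarily what llm_server's tokenize() actually uses:
from concurrent.futures import ThreadPoolExecutor
import requests
import tiktoken
tokenizer = tiktoken.get_encoding('cl100k_base')
def count_chunk(chunk: str, backend_url: str) -> int:
    try:
        r = requests.post(f'{backend_url}/tokenize', json={'input': chunk}, timeout=10)
        return r.json()['length']
    except Exception as e:
        # Keep the log terse, as in the commit: the exception class is usually enough.
        print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
        return len(tokenizer.encode(chunk)) + 10  # local fallback, padded to be safe
def tokenize(prompt: str, backend_url: str, chunk_size: int = 2000) -> int:
    chunks = [prompt[i:i + chunk_size] for i in range(0, len(prompt), chunk_size)]
    # Use a ThreadPoolExecutor to send all chunks to the server at once.
    with ThreadPoolExecutor(max_workers=8) as pool:
        return sum(pool.map(lambda c: count_chunk(c, backend_url), chunks))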

View File

@@ -17,7 +17,7 @@ class OobaRequestHandler(RequestHandler):
assert not self.used
if self.offline:
print(messages.BACKEND_OFFLINE)
self.handle_error(messages.BACKEND_OFFLINE)
return self.handle_error(messages.BACKEND_OFFLINE)
request_valid, invalid_response = self.validate_request()
if not request_valid:

View File

@@ -79,6 +79,7 @@ def openai_chat_completions(model_name=None):
event = None
if not handler.is_client_ratelimited():
start_time = time.time()
# Add a dummy event to the queue and wait for it to reach a worker
event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
if not event:
@@ -102,11 +103,14 @@
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
for item in pubsub.listen():
if time.time() - start_time >= opts.backend_generate_request_timeout:
raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
if item['type'] == 'message':
msg = item['data'].decode('utf-8')
if msg == 'begin':
break
elif msg == 'offline':
# This shouldn't happen because the best model should be auto-selected.
return return_invalid_model_err(handler.request_json_body['model'])
time.sleep(0.1)
@@ -135,6 +139,7 @@
json_obj = json.loads(json_str.decode())
new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
generated_text = generated_text + new
redis.publish(event_id, 'chunk') # Keepalive
except IndexError:
# ????
continue
@@ -170,9 +175,14 @@
r_url,
handler.backend_url,
)
except GeneratorExit:
yield 'data: [DONE]\n\n'
except:
# AttributeError: 'bool' object has no attribute 'iter_content'
traceback.print_exc()
yield 'data: [DONE]\n\n'
finally:
# After completing inference, we need to tell the worker we
# are finished.
# After completing inference, we need to tell the worker we are finished.
if event_id: # may be None if ratelimited.
redis.publish(event_id, 'finished')
else:
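Not part of the commit, just the handler-side handshake these changes implement, reduced to a standalone sketch: enqueue a dummy item so the worker pool accounts for the stream, wait on the event's pub/sub channel for 'begin', publish a 'chunk' keepalive for every piece streamed to the client, and always publish 'finished' in a finally block so the worker is released. The chunk source and the timeout constant below are placeholders:
import time
from redis import Redis
redis_client = Redis(host='localhost', port=6379)
REQUEST_TIMEOUT = 95  # stand-in for opts.backend_generate_request_timeout
def stream_response(event_id: str, chunks):
    pubsub = redis_client.pubsub()
    pubsub.subscribe(event_id)
    start_time = time.time()
    try:
        # Wait for a worker to pick up the dummy queue item and announce 'begin'.
        for item in pubsub.listen():
            if time.time() - start_time >= REQUEST_TIMEOUT:
                raise Exception('timed out waiting for a worker')
            if item['type'] == 'message' and item['data'].decode('utf-8') == 'begin':
                break
        # Stream to the client, publishing a keepalive per chunk so the worker's
        # watchdog knows the stream is still alive.
        for chunk in chunks:
            redis_client.publish(event_id, 'chunk')
            yield f'data: {chunk}\n\n'
        yield 'data: [DONE]\n\n'
    finally:
        # Release the worker even if the client disconnected mid-stream.
        redis_client.publish(event_id, 'finished')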

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
from redis import Redis
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit
@@ -23,9 +24,14 @@ def decrement_ip_count(client_ip: str, redis_key):
class RedisPriorityQueue:
def __init__(self, name, db: int = 12):
self.name = name
self.redis = RedisCustom(name, db=db)
def put(self, item, priority, selected_model):
assert item is not None
assert priority is not None
assert selected_model is not None
event = DataEvent()
# Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1])
@@ -36,7 +42,8 @@
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
timestamp = time.time()
self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
self.increment_ip_count(item[1], 'queued_ip_count')
return event
@@ -52,11 +59,13 @@
def print_all_items(self):
items = self.redis.zrange('queue', 0, -1)
to_print = []
for item in items:
print(item.decode('utf-8'))
to_print.append(item.decode('utf-8'))
print(f'ITEMS {self.name} -->', to_print)
def increment_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, 1)
self.redis.hincrby(redis_key, client_ip, 1)
def decrement_ip_count(self, client_ip: str, redis_key):
new_count = self.redis.hincrby(redis_key, client_ip, -1)
@@ -75,6 +84,16 @@
def flush(self):
self.redis.flush()
def cleanup(self):
now = time.time()
items = self.redis.zrange('queue', 0, -1)
for item in items:
item_data = json.loads(item)
timestamp = item_data[-1]
if now - timestamp > opts.backend_generate_request_timeout * 3: # TODO: config option
self.redis.zrem('queue', item)
print('removed item from queue:', item)
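The timestamp-plus-sweep idea above as a standalone sketch (not from the repo): each queue entry now carries the time it was enqueued, and a periodic cleanup drops anything that has sat in the sorted set longer than the cutoff. The commit uses three times the generate-request timeout; the value below is a placeholder:
import json
import time
from redis import Redis
r = Redis(host='localhost', port=6379, db=12)
REQUEST_TIMEOUT = 95  # stand-in for opts.backend_generate_request_timeout
def enqueue(item, priority: int, selected_model: str, event_id: str):
    timestamp = time.time()
    # The timestamp rides along as the last element of the serialized entry.
    r.zadd('queue', {json.dumps((item, event_id, selected_model, timestamp)): -priority})
def cleanup():
    now = time.time()
    for raw in r.zrange('queue', 0, -1):
        timestamp = json.loads(raw)[-1]
        if now - timestamp > REQUEST_TIMEOUT * 3:
            r.zrem('queue', raw)
            print('removed stale item from queue:', raw)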
class DataEvent:
def __init__(self, event_id=None):
@@ -112,7 +131,7 @@ def decr_active_workers(selected_model: str, backend_url: str):
class PriorityQueue:
def __init__(self, backends: list = None):
def __init__(self, backends: set = None):
"""
Only have to load the backends once.
:param backends:
@@ -120,10 +139,10 @@
self.redis = Redis(host='localhost', port=6379, db=9)
if backends:
for item in backends:
self.redis.lpush('backends', item)
self.redis.sadd('backends', item)
def get_backends(self):
return [x.decode('utf-8') for x in self.redis.lrange('backends', 0, -1)]
return {x.decode('utf-8') for x in self.redis.smembers('backends')}
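A guess at the motivation for the list-to-set switch above, with a tiny sketch (not from the repo): lpush/lrange stores the same backend URL once per registration, while sadd/smembers keep the registry deduplicated no matter how many times PriorityQueue is constructed:
from redis import Redis
r = Redis(host='localhost', port=6379, db=9)
def register_backends(backends):
    for url in backends:
        r.sadd('backends', url)  # re-adding an existing member is a no-op
def get_backends() -> set:
    return {x.decode('utf-8') for x in r.smembers('backends')}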
def get_queued_ip_count(self, client_ip: str):
count = 0
@@ -136,22 +155,32 @@
queue = RedisPriorityQueue(backend_url)
return queue.put(item, priority, selected_model)
def activity(self):
lines = []
status_redis = RedisCustom('worker_status')
for worker in status_redis.keys():
lines.append((worker, status_redis.getp(worker)))
return sorted(lines)
def len(self, model_name):
count = 0
backends_with_models = []
backends_with_models = set()
for k in self.get_backends():
info = cluster_config.get_backend(k)
if info.get('model') == model_name:
backends_with_models.append(k)
backends_with_models.add(k)
for backend_url in backends_with_models:
count += len(RedisPriorityQueue(backend_url))
return count
def __len__(self):
count = 0
p = set()
for backend_url in self.get_backends():
queue = RedisPriorityQueue(backend_url)
p.add((backend_url, len(queue)))
count += len(queue)
print(p)
return count
def flush(self):

View File

@@ -1,20 +1,48 @@
import queue
import threading
import time
import traceback
from uuid import uuid4
from redis.client import PubSub
from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count
class ListenerThread(threading.Thread):
def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
threading.Thread.__init__(self)
self.pubsub = pubsub
self.listener_queue = listener_queue
self.stop_event = stop_event
def run(self):
while not self.stop_event.is_set():
message = self.pubsub.get_message()
if message:
self.listener_queue.put(message)
time.sleep(0.1)
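Presumably why ListenerThread exists: pubsub.listen() blocks with no timeout, so the old worker loop could wait forever for 'finished' if the streamer died; pumping messages into a queue.Queue lets the consumer use q.get(timeout=...) and notice a stalled stream. A standalone usage sketch (not from the repo), with an illustrative channel name:
import queue
import threading
import time
from redis import Redis
redis_client = Redis(host='localhost', port=6379)
def wait_for_finished(channel: str, chunk_timeout: int = 30) -> bool:
    pubsub = redis_client.pubsub()
    pubsub.subscribe(channel)
    q = queue.Queue()
    stop_event = threading.Event()
    def pump():
        # Same loop as ListenerThread.run(): poll pubsub and forward into the queue.
        while not stop_event.is_set():
            message = pubsub.get_message()
            if message:
                q.put(message)
            time.sleep(0.1)
    listener = threading.Thread(target=pump, daemon=True)
    listener.start()
    try:
        while True:
            try:
                item = q.get(timeout=chunk_timeout)
            except queue.Empty:
                return False  # nothing arrived on the channel for too long
            if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
                return True
    finally:
        stop_event.set()
        listener.join()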
def worker(backend_url):
queue = RedisPriorityQueue(backend_url)
status_redis = RedisCustom('worker_status')
worker_id = uuid4()
status_redis.setp(str(worker_id), None)
redis_queue = RedisPriorityQueue(backend_url)
while True:
(request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
(request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
backend_info = cluster_config.get_backend(backend_url)
pubsub = redis.pubsub()
pubsub.subscribe(event_id)
stop_event = threading.Event()
q = queue.Queue()
listener = ListenerThread(pubsub, q, stop_event)
listener.start()
if not backend_info['online']:
redis.publish(event_id, 'offline')
@@ -26,6 +54,8 @@ def worker(backend_url):
increment_ip_count(client_ip, 'processing_ips')
incr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), (backend_url, client_ip))
try:
if not request_json_body:
# This was a dummy request from the streaming handlers.
@@ -34,13 +64,27 @@
# is finished. Since a lot of ratelimiting and stats are
# based off the number of active workers, we must keep
# the generation based off the workers.
start_time = time.time()
redis.publish(event_id, 'begin')
for item in pubsub.listen():
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
# The streaming endpoint has said that it has finished
while True:
status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))
try:
item = q.get(timeout=30)
except queue.Empty:
print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
break
if time.time() - start_time >= opts.backend_generate_request_timeout:
status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
break
if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
status_redis.setp(str(worker_id), ('streaming completed', client_ip))
break
time.sleep(0.1)
else:
status_redis.setp(str(worker_id), ('generating', client_ip))
# Normal inference (not streaming).
success, response, error_msg = generator(request_json_body, backend_url)
event = DataEvent(event_id)
@@ -48,8 +92,11 @@
except:
traceback.print_exc()
finally:
stop_event.set() # make sure to stop the listener thread
listener.join()
decrement_ip_count(client_ip, 'processing_ips')
decr_active_workers(selected_model, backend_url)
status_redis.setp(str(worker_id), None)
def start_workers(cluster: dict):
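The worker-status reporting added above, as a standalone sketch (not from the repo). The real code stores one pickled value per worker through RedisCustom('worker_status').setp() and reads it back in PriorityQueue.activity() and the console printer; this version substitutes a single Redis hash to stay self-contained:
import pickle
from uuid import uuid4
from redis import Redis
r = Redis(host='localhost', port=6379)
def set_status(worker_id: str, status):
    r.hset('worker_status', worker_id, pickle.dumps(status))
def run_one_job(backend_url: str, client_ip: str):
    worker_id = str(uuid4())
    set_status(worker_id, None)  # idle
    try:
        set_status(worker_id, (backend_url, client_ip))   # picked up a job
        set_status(worker_id, ('generating', client_ip))  # phase updates during the job
        # ... inference happens here ...
    finally:
        set_status(worker_id, None)  # always clear on the way out, even on errors
def activity():
    # Mirrors PriorityQueue.activity(): a sorted list of (worker_id, status) pairs.
    return sorted((k.decode(), pickle.loads(v)) for k, v in r.hgetall('worker_status').items())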

View File

@@ -7,6 +7,7 @@ from llm_server.cluster.cluster_config import cluster_config, get_backends
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
from llm_server.routes.queue import RedisPriorityQueue, priority_queue
def main_background_thread():
@@ -35,6 +36,11 @@ def main_background_thread():
except Exception as e:
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
backends = priority_queue.get_backends()
for backend_url in backends:
queue = RedisPriorityQueue(backend_url)
queue.cleanup()
time.sleep(30)

View File

@@ -24,5 +24,7 @@ def console_printer():
for k in processing:
processing_count += redis.get(k, default=0, dtype=int)
backends = [k for k, v in cluster_config.all().items() if v['online']]
logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(10)
activity = priority_queue.activity()
print(activity)
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(1)