fix issues with queue and streaming
parent 3ec9b2347f
commit 31ab4188f1

@ -200,10 +200,10 @@ class RedisCustom(Redis):
        return json.loads(r.decode("utf-8"))

    def setp(self, name, value):
        self.redis.set(name, pickle.dumps(value))
        self.redis.set(self._key(name), pickle.dumps(value))

    def getp(self, name: str):
        r = self.redis.get(name)
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r

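For context, here is a minimal sketch of the key-namespacing pattern the setp/getp fix relies on, assuming _key() simply prefixes each key with the instance name. The class below is an illustrative stand-in, not the project's actual RedisCustom.

import pickle

from redis import Redis


class NamespacedRedis:
    # Illustrative stand-in for RedisCustom: every key is prefixed with the
    # instance name so several queues can share one Redis database.
    def __init__(self, name: str, db: int = 0):
        self.name = name
        self.redis = Redis(host='localhost', port=6379, db=db)

    def _key(self, name: str) -> str:
        # Assumed prefixing scheme; the real implementation may differ.
        return f'{self.name}:{name}'

    def setp(self, name, value):
        # Pickle arbitrary Python objects under the namespaced key.
        self.redis.set(self._key(name), pickle.dumps(value))

    def getp(self, name: str):
        r = self.redis.get(self._key(name))
        if r:
            return pickle.loads(r)
        return r

Namespacing matters here because the per-backend queues and the worker status board share the same Redis server, so un-prefixed keys written by different instances could collide.
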
@ -1,79 +1,6 @@
from flask import jsonify

from llm_server.custom_redis import redis
from ..llm_backend import LLMBackend
from ...database.database import do_db_log
from ...helpers import safe_list_get
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json


class OobaboogaBackend(LLMBackend):
    default_params = {}

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        raise NotImplementedError('need to implement default_params')

        backend_err = False
        response_valid_json, response_json_body = validate_json(response)
        if response:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response or error_msg:
            if not error_msg or error_msg == '':
                error_msg = 'Unknown error.'
            else:
                error_msg = error_msg.strip('.') + '.'
            backend_response = format_sillytavern_err(error_msg, error_type='error', backend_url=self.backend_url)
            log_to_db(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': error_msg,
                'results': [{'text': backend_response}]
            }), 400

        # ===============================================

        if response_valid_json:
            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
            if not backend_response:
                # Ooba doesn't return any error messages so we will just tell the client an error occurred
                backend_err = True
                backend_response = format_sillytavern_err(
                    f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                    error_type='error',
                    backend_url=self.backend_url)
                response_json_body['results'][0]['text'] = backend_response

            if not backend_err:
                redis.incr('proompts')

            log_to_db(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({
                **response_json_body
            }), 200
        else:
            backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', error_type='error', backend_url=self.backend_url)
            log_to_db(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response.status_code, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 400

    def validate_params(self, params_dict: dict):
        # No validation required
        return True, None

    def get_parameters(self, parameters):
        del parameters['prompt']
        return parameters

    def __int__(self):
        return

@ -33,7 +33,7 @@ def tokenize(prompt: str, backend_url: str) -> int:
            j = r.json()
            return j['length']
        except Exception as e:
            print(f'Failed to tokenize using VLLM -', f'{e.__class__.__name__}: {e}')
            print(f'Failed to tokenize using VLLM - {e.__class__.__name__}')
            return len(tokenizer.encode(chunk)) + 10

    # Use a ThreadPoolExecutor to send all chunks to the server at once

@ -17,7 +17,7 @@ class OobaRequestHandler(RequestHandler):
        assert not self.used
        if self.offline:
            print(messages.BACKEND_OFFLINE)
            self.handle_error(messages.BACKEND_OFFLINE)
            return self.handle_error(messages.BACKEND_OFFLINE)

        request_valid, invalid_response = self.validate_request()
        if not request_valid:

@ -79,6 +79,7 @@ def openai_chat_completions(model_name=None):

    event = None
    if not handler.is_client_ratelimited():
        start_time = time.time()
        # Add a dummy event to the queue and wait for it to reach a worker
        event = priority_queue.put(handler.backend_url, (None, handler.client_ip, handler.token, None), handler.token_priority, handler.selected_model)
    if not event:

@ -102,11 +103,14 @@ def openai_chat_completions(model_name=None):
    pubsub = redis.pubsub()
    pubsub.subscribe(event_id)
    for item in pubsub.listen():
        if time.time() - start_time >= opts.backend_generate_request_timeout:
            raise Exception('Inferencer timed out waiting for streaming to complete:', request_json_body)
        if item['type'] == 'message':
            msg = item['data'].decode('utf-8')
            if msg == 'begin':
                break
            elif msg == 'offline':
                # This shouldn't happen because the best model should be auto-selected.
                return return_invalid_model_err(handler.request_json_body['model'])
        time.sleep(0.1)

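The loop above is the requesting side of a per-request pub/sub handshake: the endpoint blocks until the worker publishes 'begin' on the event channel, returns an error if the worker reports 'offline', and gives up after the generation timeout. A condensed sketch of that wait, with the connection details, channel name, and timeout value assumed for illustration:

import time

from redis import Redis

redis_client = Redis(host='localhost', port=6379)


def wait_for_begin(event_id: str, timeout: float = 95.0) -> str:
    # Subscribe to the per-request channel and wait for the worker's signal.
    pubsub = redis_client.pubsub()
    pubsub.subscribe(event_id)
    start_time = time.time()
    try:
        while True:
            if time.time() - start_time >= timeout:
                return 'timeout'
            item = pubsub.get_message()
            if item and item['type'] == 'message':
                msg = item['data'].decode('utf-8')
                if msg in ('begin', 'offline'):
                    return msg
            time.sleep(0.1)
    finally:
        pubsub.close()

Polling with get_message() keeps the timeout check live even when nothing is published, whereas a blocking listen() loop only re-checks the clock after a message arrives.
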
@ -135,6 +139,7 @@ def openai_chat_completions(model_name=None):
                json_obj = json.loads(json_str.decode())
                new = json_obj['text'][0].split(handler.prompt + generated_text)[1]
                generated_text = generated_text + new
                redis.publish(event_id, 'chunk')  # Keepalive
            except IndexError:
                # ????
                continue

@ -170,9 +175,14 @@ def openai_chat_completions(model_name=None):
                        r_url,
                        handler.backend_url,
                    )
                except GeneratorExit:
                    yield 'data: [DONE]\n\n'
                except:
                    # AttributeError: 'bool' object has no attribute 'iter_content'
                    traceback.print_exc()
                    yield 'data: [DONE]\n\n'
                finally:
                    # After completing inference, we need to tell the worker we
                    # are finished.
                    # After completing inference, we need to tell the worker we are finished.
                    if event_id:  # may be None if ratelimited.
                        redis.publish(event_id, 'finished')
                    else:

@ -6,6 +6,7 @@ from uuid import uuid4

from redis import Redis

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import RedisCustom, redis
from llm_server.database.database import get_token_ratelimit

@ -23,9 +24,14 @@ def decrement_ip_count(client_ip: str, redis_key):

class RedisPriorityQueue:
    def __init__(self, name, db: int = 12):
        self.name = name
        self.redis = RedisCustom(name, db=db)

    def put(self, item, priority, selected_model):
        assert item is not None
        assert priority is not None
        assert selected_model is not None

        event = DataEvent()
        # Check if the IP is already in the dictionary and if it has reached the limit
        ip_count = self.redis.hget('queued_ip_count', item[1])

@ -36,7 +42,8 @@ class RedisPriorityQueue:
            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
            return None  # reject the request

        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model)): -priority})
        timestamp = time.time()
        self.redis.zadd('queue', {json.dumps((item, event.event_id, selected_model, timestamp)): -priority})
        self.increment_ip_count(item[1], 'queued_ip_count')
        return event

@ -52,11 +59,13 @@ class RedisPriorityQueue:

    def print_all_items(self):
        items = self.redis.zrange('queue', 0, -1)
        to_print = []
        for item in items:
            print(item.decode('utf-8'))
            to_print.append(item.decode('utf-8'))
        print(f'ITEMS {self.name} -->', to_print)

    def increment_ip_count(self, client_ip: str, redis_key):
        new_count = self.redis.hincrby(redis_key, client_ip, 1)
        self.redis.hincrby(redis_key, client_ip, 1)

    def decrement_ip_count(self, client_ip: str, redis_key):
        new_count = self.redis.hincrby(redis_key, client_ip, -1)

@ -75,6 +84,16 @@ class RedisPriorityQueue:
    def flush(self):
        self.redis.flush()

    def cleanup(self):
        now = time.time()
        items = self.redis.zrange('queue', 0, -1)
        for item in items:
            item_data = json.loads(item)
            timestamp = item_data[-1]
            if now - timestamp > opts.backend_generate_request_timeout * 3:  # TODO: config option
                self.redis.zrem('queue', item)
                print('removed item from queue:', item)


class DataEvent:
    def __init__(self, event_id=None):

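The new cleanup() pairs with the timestamp that put() now stores inside each queue entry: a background task can scan the sorted set and drop requests that have been waiting several times longer than the generation timeout. A self-contained sketch of that idea on a plain sorted set (the key name, grace factor, and connection details are assumptions, and the real entries also carry the event id):

import json
import time

from redis import Redis

redis_client = Redis(host='localhost', port=6379, db=12)
REQUEST_TIMEOUT = 95  # seconds; stand-in for opts.backend_generate_request_timeout


def enqueue(item, priority: int, selected_model: str):
    # Store the enqueue time inside the member itself so cleanup can read it back.
    member = json.dumps((item, selected_model, time.time()))
    redis_client.zadd('queue', {member: -priority})


def cleanup():
    now = time.time()
    for member in redis_client.zrange('queue', 0, -1):
        timestamp = json.loads(member)[-1]
        if now - timestamp > REQUEST_TIMEOUT * 3:
            redis_client.zrem('queue', member)
            print('removed stale item from queue:', member)
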
@ -112,7 +131,7 @@ def decr_active_workers(selected_model: str, backend_url: str):


class PriorityQueue:
    def __init__(self, backends: list = None):
    def __init__(self, backends: set = None):
        """
        Only have to load the backends once.
        :param backends:

@ -120,10 +139,10 @@ class PriorityQueue:
        self.redis = Redis(host='localhost', port=6379, db=9)
        if backends:
            for item in backends:
                self.redis.lpush('backends', item)
                self.redis.sadd('backends', item)

    def get_backends(self):
        return [x.decode('utf-8') for x in self.redis.lrange('backends', 0, -1)]
        return {x.decode('utf-8') for x in self.redis.smembers('backends')}

    def get_queued_ip_count(self, client_ip: str):
        count = 0

@ -136,22 +155,32 @@ class PriorityQueue:
        queue = RedisPriorityQueue(backend_url)
        return queue.put(item, priority, selected_model)

    def activity(self):
        lines = []
        status_redis = RedisCustom('worker_status')
        for worker in status_redis.keys():
            lines.append((worker, status_redis.getp(worker)))
        return sorted(lines)

    def len(self, model_name):
        count = 0
        backends_with_models = []
        backends_with_models = set()
        for k in self.get_backends():
            info = cluster_config.get_backend(k)
            if info.get('model') == model_name:
                backends_with_models.append(k)
                backends_with_models.add(k)
        for backend_url in backends_with_models:
            count += len(RedisPriorityQueue(backend_url))
        return count

    def __len__(self):
        count = 0
        p = set()
        for backend_url in self.get_backends():
            queue = RedisPriorityQueue(backend_url)
            p.add((backend_url, len(queue)))
            count += len(queue)
        print(p)
        return count

    def flush(self):

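The added activity() method reads a small status board: each worker stores its current state under its own UUID, and activity() simply collects and sorts the entries. A rough sketch of the same pattern using a single Redis hash instead of the project's RedisCustom('worker_status') namespace (names and connection details are illustrative):

import pickle
from uuid import uuid4

from redis import Redis

redis_client = Redis(host='localhost', port=6379)


def register_worker() -> str:
    # Each worker announces itself under a unique id with an empty status.
    worker_id = str(uuid4())
    redis_client.hset('worker_status', worker_id, pickle.dumps(None))
    return worker_id


def set_status(worker_id: str, status):
    # e.g. ('generating', client_ip) or ('streaming completed', client_ip)
    redis_client.hset('worker_status', worker_id, pickle.dumps(status))


def activity():
    # Mirrors PriorityQueue.activity(): one (worker_id, status) row per worker.
    board = redis_client.hgetall('worker_status')
    return sorted((k.decode(), pickle.loads(v)) for k, v in board.items())
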
@ -1,20 +1,48 @@
import queue
import threading
import time
import traceback
from uuid import uuid4

from redis.client import PubSub

from llm_server import opts
from llm_server.cluster.cluster_config import cluster_config
from llm_server.custom_redis import redis
from llm_server.custom_redis import RedisCustom, redis
from llm_server.llm.generator import generator
from llm_server.routes.queue import DataEvent, RedisPriorityQueue, decr_active_workers, decrement_ip_count, incr_active_workers, increment_ip_count


class ListenerThread(threading.Thread):
    def __init__(self, pubsub: PubSub, listener_queue: queue.Queue, stop_event: threading.Event):
        threading.Thread.__init__(self)
        self.pubsub = pubsub
        self.listener_queue = listener_queue
        self.stop_event = stop_event

    def run(self):
        while not self.stop_event.is_set():
            message = self.pubsub.get_message()
            if message:
                self.listener_queue.put(message)
            time.sleep(0.1)


def worker(backend_url):
    queue = RedisPriorityQueue(backend_url)
    status_redis = RedisCustom('worker_status')
    worker_id = uuid4()
    status_redis.setp(str(worker_id), None)
    redis_queue = RedisPriorityQueue(backend_url)
    while True:
        (request_json_body, client_ip, token, parameters), event_id, selected_model = queue.get()
        (request_json_body, client_ip, token, parameters), event_id, selected_model, timestamp = redis_queue.get()
        backend_info = cluster_config.get_backend(backend_url)

        pubsub = redis.pubsub()
        pubsub.subscribe(event_id)
        stop_event = threading.Event()
        q = queue.Queue()
        listener = ListenerThread(pubsub, q, stop_event)
        listener.start()

        if not backend_info['online']:
            redis.publish(event_id, 'offline')

@ -26,6 +54,8 @@ def worker(backend_url):
        increment_ip_count(client_ip, 'processing_ips')
        incr_active_workers(selected_model, backend_url)

        status_redis.setp(str(worker_id), (backend_url, client_ip))

        try:
            if not request_json_body:
                # This was a dummy request from the streaming handlers.

@ -34,13 +64,27 @@ def worker(backend_url):
                # is finished. Since a lot of ratelimiting and stats are
                # based off the number of active workers, we must keep
                # the generation based off the workers.
                start_time = time.time()
                redis.publish(event_id, 'begin')
                for item in pubsub.listen():
                    if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
                        # The streaming endpoint has said that it has finished
                while True:
                    status_redis.setp(str(worker_id), (f'waiting for streaming to complete - {time.time() - start_time} - {opts.backend_generate_request_timeout}', client_ip))

                    try:
                        item = q.get(timeout=30)
                    except queue.Empty:
                        print('Inferencer timed out waiting for chunk from streamer:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
                        status_redis.setp(str(worker_id), ('streaming chunk timed out', client_ip))
                        break

                    if time.time() - start_time >= opts.backend_generate_request_timeout:
                        status_redis.setp(str(worker_id), ('streaming timed out', client_ip))
                        print('Inferencer timed out waiting for streaming to complete:', (request_json_body, client_ip, token, parameters), event_id, selected_model)
                        break
                    if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
                        status_redis.setp(str(worker_id), ('streaming completed', client_ip))
                        break
                    time.sleep(0.1)
            else:
                status_redis.setp(str(worker_id), ('generating', client_ip))
                # Normal inference (not streaming).
                success, response, error_msg = generator(request_json_body, backend_url)
                event = DataEvent(event_id)

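On the worker side, the dummy streaming request is now supervised through the listener thread's queue.Queue: the streaming endpoint publishes 'chunk' keepalives and finally 'finished', while the worker gives up if the channel stays silent for 30 seconds or the overall timeout passes. A condensed sketch of that wait loop (the timeout values and return strings are illustrative):

import queue
import time

OVERALL_TIMEOUT = 95   # stand-in for opts.backend_generate_request_timeout
CHUNK_TIMEOUT = 30     # maximum silence allowed between keepalive messages


def wait_for_streaming(q: queue.Queue) -> str:
    # q is fed by a listener thread that pushes raw pub/sub messages into it.
    start_time = time.time()
    while True:
        try:
            item = q.get(timeout=CHUNK_TIMEOUT)
        except queue.Empty:
            return 'streaming chunk timed out'
        if time.time() - start_time >= OVERALL_TIMEOUT:
            return 'streaming timed out'
        if item['type'] == 'message' and item['data'].decode('utf-8') == 'finished':
            return 'streaming completed'
        # Anything else ('chunk' keepalives, subscribe confirmations) just
        # starts a fresh 30-second silence window on the next q.get().
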
@ -48,8 +92,11 @@ def worker(backend_url):
        except:
            traceback.print_exc()
        finally:
            stop_event.set()  # make sure to stop the listener thread
            listener.join()
            decrement_ip_count(client_ip, 'processing_ips')
            decr_active_workers(selected_model, backend_url)
            status_redis.setp(str(worker_id), None)


def start_workers(cluster: dict):

@ -7,6 +7,7 @@ from llm_server.cluster.cluster_config import cluster_config, get_backends
from llm_server.custom_redis import redis
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_info
from llm_server.routes.queue import RedisPriorityQueue, priority_queue


def main_background_thread():

@ -35,6 +36,11 @@ def main_background_thread():
        except Exception as e:
            print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')

        backends = priority_queue.get_backends()
        for backend_url in backends:
            queue = RedisPriorityQueue(backend_url)
            queue.cleanup()

        time.sleep(30)

@ -24,5 +24,7 @@ def console_printer():
        for k in processing:
            processing_count += redis.get(k, default=0, dtype=int)
        backends = [k for k, v in cluster_config.all().items() if v['online']]
        logger.info(f'REQUEST QUEUE -> Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
        time.sleep(10)
        activity = priority_queue.activity()
        print(activity)
        logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
        time.sleep(1)