remove timed-out items from queue

This commit is contained in:
Cyberes 2023-10-17 11:46:39 -06:00
parent 7998cfca87
commit 2fed87d340
10 changed files with 80 additions and 37 deletions

View File

@ -232,5 +232,22 @@ class RedisCustom(Redis):
def lpop(self, name: str, count: Optional[int] = None):
return self.redis.lpop(self._key(name), count)
def zrange(
self,
name: KeyT,
start: int,
end: int,
desc: bool = False,
withscores: bool = False,
score_cast_func: Union[type, Callable] = float,
byscore: bool = False,
bylex: bool = False,
offset: int = None,
num: int = None,
):
return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
def zrem(self, name: KeyT, *values: FieldT):
return self.redis.zrem(self._key(name), *values)
redis = RedisCustom('local_llm')

View File

@ -30,6 +30,7 @@ class VLLMBackend(LLMBackend):
if top_k <= 0:
top_k = -1
# TODO: support more params
sampling_params = SamplingParams(
temperature=parameters.get('temperature', self._default_params['temperature']),
top_p=parameters.get('top_p', self._default_params['top_p']),

View File

@ -1,8 +1,8 @@
import json
import ujson
import time
import traceback
import ujson
from flask import Response, jsonify, request
from redis import Redis
@ -97,8 +97,14 @@ def openai_chat_completions(model_name=None):
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
# Need to do this before we enter generate() since we want to be able to
# return a 408 if necessary.
_, stream_name, error_msg = event.wait()
if error_msg == 'closed':
stream_name = None # set to null so that the Finally ignores it.
return 'Request Timeout', 408
def generate():
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
try:
@ -159,6 +165,7 @@ def openai_chat_completions(model_name=None):
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')

View File

@ -143,8 +143,12 @@ def openai_completions(model_name=None):
model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
oai_string = generate_oai_string(30)
_, stream_name, error_msg = event.wait()
if error_msg == 'closed':
stream_name = None
return 'Request Timeout', 408
def generate():
stream_name = event.wait()
stream_redis = Redis(db=8)
generated_text = ''
try:
@ -206,6 +210,7 @@ def openai_completions(model_name=None):
finally:
if event:
redis.publish(f'notifications:{event.event_id}', 'canceled')
if stream_name:
stream_redis.delete(stream_name)
return Response(generate(), mimetype='text/event-stream')

View File

@ -57,12 +57,12 @@ class RedisPriorityQueue:
return item
time.sleep(0.1) # wait for something to be added to the queue
def print_all_items(self):
items = self.redis.zrange('queue', 0, -1)
to_print = []
for item in items:
to_print.append(item.decode('utf-8'))
print(f'ITEMS {self.name} -->', to_print)
# def print_all_items(self):
# items = self.redis.zrange('queue', 0, -1)
# to_print = []
# for item in items:
# to_print.append(item.decode('utf-8'))
# print(f'ITEMS {self.name} -->', to_print)
def increment_ip_count(self, client_ip: str, redis_key):
self.redis.hincrby(redis_key, client_ip, 1)
@ -84,15 +84,23 @@ class RedisPriorityQueue:
def flush(self):
self.redis.flush()
def items(self):
return self.redis.zrange('queue', 0, -1)
def cleanup(self):
now = time.time()
items = self.redis.zrange('queue', 0, -1)
for item in items:
for item in self.items():
item_data = json.loads(item)
timestamp = item_data[-1]
if now - timestamp > opts.backend_generate_request_timeout * 3: # TODO: config option
self.redis.zrem('queue', item)
print('removed item from queue:', item)
timestamp = item_data[-2]
if now - timestamp > opts.backend_generate_request_timeout:
self.redis.zrem('queue', 0, item)
data = json.loads(item.decode('utf-8'))
event_id = data[1]
client_ip = data[0][1]
self.decrement_ip_count(client_ip, 'queued_ip_count')
event = DataEvent(event_id)
event.set((False, None, 'closed'))
print('Removed timed-out item from queue:', event_id)
class DataEvent:

View File

@ -148,6 +148,8 @@ class RequestHandler:
# TODO: add wait timeout
success, response, error_msg = event.wait()
if error_msg == 'closed':
return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
end_time = time.time()
elapsed_time = end_time - self.start_time

View File

@ -129,7 +129,11 @@ def do_stream(ws, model_name):
return
event_id = event.event_id
stream_name = event.wait()
_, stream_name, error_msg = event.wait()
if error_msg == 'closed':
ws.close(reason=1014, message='Request Timeout')
return
stream_redis = Redis(db=8)
generated_text = ''
@ -170,6 +174,7 @@ def do_stream(ws, model_name):
except:
# The client closed the stream.
pass
if stream_name:
stream_redis.delete(stream_name)
end_time = time.time()
elapsed_time = end_time - start_time

View File

@ -105,7 +105,7 @@ def worker(backend_url):
if do_stream:
# Return the name of the stream that the slave should connect to.
event = DataEvent(event_id)
event.set(get_stream_name(worker_id))
event.set((True, get_stream_name(worker_id), None))
msg_to_backend = {
**parameters,

View File

@ -36,7 +36,6 @@ def main_background_thread():
except Exception as e:
print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
# TODO: test
backends = priority_queue.get_backends()
for backend_url in backends:
queue = RedisPriorityQueue(backend_url)

View File

@ -26,7 +26,6 @@ def console_printer():
backends = [k for k, v in cluster_config.all().items() if v['online']]
activity = priority_queue.activity()
# TODO: Active Workers and Processing should read the same. If not, that's an issue
# Active Workers and Processing should read the same. If not, that's an issue.
logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
time.sleep(2)
time.sleep(10)