diff --git a/llm_server/custom_redis.py b/llm_server/custom_redis.py
index 60e4dbd..886230a 100644
--- a/llm_server/custom_redis.py
+++ b/llm_server/custom_redis.py
@@ -232,5 +232,22 @@ class RedisCustom(Redis):
     def lpop(self, name: str, count: Optional[int] = None):
         return self.redis.lpop(self._key(name), count)
 
+    def zrange(
+            self,
+            name: KeyT,
+            start: int,
+            end: int,
+            desc: bool = False,
+            withscores: bool = False,
+            score_cast_func: Union[type, Callable] = float,
+            byscore: bool = False,
+            bylex: bool = False,
+            offset: int = None,
+            num: int = None,
+    ):
+        return self.redis.zrange(self._key(name), start, end, desc, withscores, score_cast_func, byscore, bylex, offset, num)
+
+    def zrem(self, name: KeyT, *values: FieldT):
+        return self.redis.zrem(self._key(name), *values)
 
 redis = RedisCustom('local_llm')
diff --git a/llm_server/llm/vllm/vllm_backend.py b/llm_server/llm/vllm/vllm_backend.py
index 9665547..5c12b45 100644
--- a/llm_server/llm/vllm/vllm_backend.py
+++ b/llm_server/llm/vllm/vllm_backend.py
@@ -30,6 +30,7 @@ class VLLMBackend(LLMBackend):
             if top_k <= 0:
                 top_k = -1
 
+            # TODO: support more params
            sampling_params = SamplingParams(
                temperature=parameters.get('temperature', self._default_params['temperature']),
                top_p=parameters.get('top_p', self._default_params['top_p']),
diff --git a/llm_server/routes/openai/chat_completions.py b/llm_server/routes/openai/chat_completions.py
index fbac971..475ff00 100644
--- a/llm_server/routes/openai/chat_completions.py
+++ b/llm_server/routes/openai/chat_completions.py
@@ -1,8 +1,8 @@
 import json
-import ujson
 import time
 import traceback
+import ujson
 
 from flask import Response, jsonify, request
 from redis import Redis
 
@@ -97,8 +97,14 @@ def openai_chat_completions(model_name=None):
             model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
             oai_string = generate_oai_string(30)
 
+            # Need to do this before we enter generate() since we want to be able to
+            # return a 408 if necessary.
+            _, stream_name, error_msg = event.wait()
+            if error_msg == 'closed':
+                stream_name = None  # set to null so that the finally block ignores it.
+                return 'Request Timeout', 408
+
             def generate():
-                stream_name = event.wait()
                 stream_redis = Redis(db=8)
                 generated_text = ''
                 try:
@@ -159,7 +165,8 @@ def openai_chat_completions(model_name=None):
                 finally:
                     if event:
                         redis.publish(f'notifications:{event.event_id}', 'canceled')
-                    stream_redis.delete(stream_name)
+                    if stream_name:
+                        stream_redis.delete(stream_name)
 
             return Response(generate(), mimetype='text/event-stream')
         except Exception:
diff --git a/llm_server/routes/openai/completions.py b/llm_server/routes/openai/completions.py
index e7b85ea..5dfacf3 100644
--- a/llm_server/routes/openai/completions.py
+++ b/llm_server/routes/openai/completions.py
@@ -143,8 +143,12 @@ def openai_completions(model_name=None):
             model = redis.get('running_model', 'ERROR', dtype=str) if opts.openai_expose_our_model else request_json_body.get('model')
             oai_string = generate_oai_string(30)
 
+            _, stream_name, error_msg = event.wait()
+            if error_msg == 'closed':
+                stream_name = None
+                return 'Request Timeout', 408
+
             def generate():
-                stream_name = event.wait()
                 stream_redis = Redis(db=8)
                 generated_text = ''
                 try:
@@ -206,7 +210,8 @@ def openai_completions(model_name=None):
                 finally:
                     if event:
                         redis.publish(f'notifications:{event.event_id}', 'canceled')
-                    stream_redis.delete(stream_name)
+                    if stream_name:
+                        stream_redis.delete(stream_name)
 
             return Response(generate(), mimetype='text/event-stream')
         except Exception:
diff --git a/llm_server/routes/queue.py b/llm_server/routes/queue.py
index cb4aaf5..834c844 100644
--- a/llm_server/routes/queue.py
+++ b/llm_server/routes/queue.py
@@ -57,12 +57,12 @@ class RedisPriorityQueue:
                 return item
             time.sleep(0.1)  # wait for something to be added to the queue
 
-    def print_all_items(self):
-        items = self.redis.zrange('queue', 0, -1)
-        to_print = []
-        for item in items:
-            to_print.append(item.decode('utf-8'))
-        print(f'ITEMS {self.name} -->', to_print)
+    # def print_all_items(self):
+    #     items = self.redis.zrange('queue', 0, -1)
+    #     to_print = []
+    #     for item in items:
+    #         to_print.append(item.decode('utf-8'))
+    #     print(f'ITEMS {self.name} -->', to_print)
 
     def increment_ip_count(self, client_ip: str, redis_key):
         self.redis.hincrby(redis_key, client_ip, 1)
@@ -84,15 +84,23 @@ class RedisPriorityQueue:
     def flush(self):
         self.redis.flush()
 
+    def items(self):
+        return self.redis.zrange('queue', 0, -1)
+
     def cleanup(self):
         now = time.time()
-        items = self.redis.zrange('queue', 0, -1)
-        for item in items:
+        for item in self.items():
             item_data = json.loads(item)
-            timestamp = item_data[-1]
-            if now - timestamp > opts.backend_generate_request_timeout * 3:  # TODO: config option
-                self.redis.zrem('queue', item)
-                print('removed item from queue:', item)
+            timestamp = item_data[-2]
+            if now - timestamp > opts.backend_generate_request_timeout:
+                self.redis.zrem('queue', item)
+                data = json.loads(item.decode('utf-8'))
+                event_id = data[1]
+                client_ip = data[0][1]
+                self.decrement_ip_count(client_ip, 'queued_ip_count')
+                event = DataEvent(event_id)
+                event.set((False, None, 'closed'))
+                print('Removed timed-out item from queue:', event_id)
 
 
 class DataEvent:
diff --git a/llm_server/routes/request_handler.py b/llm_server/routes/request_handler.py
index 0fe94ec..df07c29 100644
--- a/llm_server/routes/request_handler.py
+++ b/llm_server/routes/request_handler.py
@@ -148,6 +148,8 @@ class RequestHandler:
 
         # TODO: add wait timeout
         success, response, error_msg = event.wait()
+        if error_msg == 'closed':
+            return (False, None, None, 0), (self.handle_error('Request Timeout')[0], 408)
         end_time = time.time()
         elapsed_time = end_time - self.start_time
 
diff --git a/llm_server/routes/v1/generate_stream.py b/llm_server/routes/v1/generate_stream.py
index 72f4bad..cdf939d 100644
--- a/llm_server/routes/v1/generate_stream.py
+++ b/llm_server/routes/v1/generate_stream.py
@@ -129,7 +129,11 @@ def do_stream(ws, model_name):
             return
 
         event_id = event.event_id
-        stream_name = event.wait()
+        _, stream_name, error_msg = event.wait()
+        if error_msg == 'closed':
+            ws.close(reason=1014, message='Request Timeout')
+            return
+
         stream_redis = Redis(db=8)
         generated_text = ''
 
@@ -170,20 +174,21 @@ def do_stream(ws, model_name):
            except:
                # The client closed the stream.
                pass
-        stream_redis.delete(stream_name)
-        end_time = time.time()
-        elapsed_time = end_time - start_time
-        log_to_db(ip=handler.client_ip,
-                  token=handler.token,
-                  prompt=input_prompt,
-                  response=generated_text,
-                  gen_time=elapsed_time,
-                  parameters=handler.parameters,
-                  headers=r_headers,
-                  backend_response_code=response_status_code,
-                  request_url=r_url,
-                  backend_url=handler.backend_url
-                  )
+        if stream_name:
+            stream_redis.delete(stream_name)
+            end_time = time.time()
+            elapsed_time = end_time - start_time
+            log_to_db(ip=handler.client_ip,
+                      token=handler.token,
+                      prompt=input_prompt,
+                      response=generated_text,
+                      gen_time=elapsed_time,
+                      parameters=handler.parameters,
+                      headers=r_headers,
+                      backend_response_code=response_status_code,
+                      request_url=r_url,
+                      backend_url=handler.backend_url
+                      )
     finally:
         if event_id:
             redis.publish(f'notifications:{event_id}', 'canceled')
diff --git a/llm_server/workers/inferencer.py b/llm_server/workers/inferencer.py
index d1a3ceb..0738c1b 100644
--- a/llm_server/workers/inferencer.py
+++ b/llm_server/workers/inferencer.py
@@ -105,7 +105,7 @@ def worker(backend_url):
         if do_stream:
             # Return the name of the stream that the slave should connect to.
             event = DataEvent(event_id)
-            event.set(get_stream_name(worker_id))
+            event.set((True, get_stream_name(worker_id), None))
 
         msg_to_backend = {
             **parameters,
diff --git a/llm_server/workers/mainer.py b/llm_server/workers/mainer.py
index e06e803..d342f4b 100644
--- a/llm_server/workers/mainer.py
+++ b/llm_server/workers/mainer.py
@@ -36,7 +36,6 @@ def main_background_thread():
         except Exception as e:
             print(f'Failed fetch the homepage - {e.__class__.__name__}: {e}')
 
-        # TODO: test
         backends = priority_queue.get_backends()
         for backend_url in backends:
             queue = RedisPriorityQueue(backend_url)
diff --git a/llm_server/workers/printer.py b/llm_server/workers/printer.py
index 4025df3..fe0d129 100644
--- a/llm_server/workers/printer.py
+++ b/llm_server/workers/printer.py
@@ -26,7 +26,6 @@ def console_printer():
 
         backends = [k for k, v in cluster_config.all().items() if v['online']]
         activity = priority_queue.activity()
-        # TODO: Active Workers and Processing should read the same. If not, that's an issue
-
+        # Active Workers and Processing should read the same. If not, that's an issue.
         logger.info(f'REQUEST QUEUE -> Active Workers: {len([i for i in activity if i[1]])} | Processing: {processing_count} | Queued: {len(priority_queue)} | Backends Online: {len(backends)}')
-        time.sleep(2)
+        time.sleep(10)
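
Note on the handoff these hunks change: DataEvent.set()/wait() now carries a (success, stream_name, error_msg) tuple instead of a bare stream name, so the queue reaper in RedisPriorityQueue.cleanup() can publish (False, None, 'closed') for a timed-out request and every waiter (the OpenAI routes, RequestHandler, and the websocket stream) can turn that sentinel into a 408 or a close(1014) instead of hanging. The sketch below only illustrates that convention and is not the project's DataEvent implementation: MiniDataEvent, the 'events:' channel prefix, and db=14 are made-up names, and it assumes a pickle-over-Redis-pub/sub event similar in spirit to DataEvent.

import pickle
import uuid

from redis import Redis


class MiniDataEvent:
    """Illustrative stand-in for DataEvent: a one-shot handoff over Redis pub/sub."""

    def __init__(self, event_id: str = None):
        self.event_id = event_id or str(uuid.uuid4())
        self.redis = Redis(db=14)  # hypothetical db number, not the project's
        self.pubsub = self.redis.pubsub()
        self.pubsub.subscribe(f'events:{self.event_id}')  # hypothetical channel prefix

    def set(self, data):
        # Producer side: the inferencer publishes (True, stream_name, None);
        # cleanup() publishes (False, None, 'closed') when a queued item times out.
        self.redis.publish(f'events:{self.event_id}', pickle.dumps(data))

    def wait(self):
        # Consumer side: block until the producer publishes the tuple.
        for message in self.pubsub.listen():
            if message['type'] == 'message':
                return pickle.loads(message['data'])


def wait_for_stream(event: MiniDataEvent):
    # Mirrors the pattern added to the route handlers above.
    _, stream_name, error_msg = event.wait()
    if error_msg == 'closed':
        # The queue reaper expired the request before a worker picked it up.
        return 'Request Timeout', 408
    return stream_name  # name of the Redis stream to read tokens from

Keeping the error in the same tuple rather than introducing a second event means each handler keeps its single wait() call site; only the unpacking and the early-return branch change.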