clean up background threads
This commit is contained in:
parent 35e9847b27
commit 43299b32ad
@@ -31,7 +31,8 @@ config_default_vars = {
     'openai_moderation_workers': 10,
     'openai_org_name': 'OpenAI',
     'openai_silent_trim': False,
-    'openai_moderation_enabled': True
+    'openai_moderation_enabled': True,
+    'netdata_root': None
 }
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

@@ -6,7 +6,7 @@ import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import FieldT, ExpiryT
+from redis.typing import ExpiryT, FieldT

 flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})

@@ -72,6 +72,20 @@ class RedisWrapper:
     def sismember(self, key: str, value: str):
         return self.redis.sismember(self._key(key), value)

+    def lindex(
+            self, name: str, index: int
+    ):
+        return self.redis.lindex(self._key(name), index)
+
+    def lrem(self, name: str, count: int, value: str):
+        return self.redis.lrem(self._key(name), count, value)
+
+    def rpush(self, name: str, *values: FieldT):
+        return self.redis.rpush(self._key(name), *values)
+
+    def llen(self, name: str):
+        return self.redis.llen(self._key(name))
+
     def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
         return self.set(key, json.dumps(dict_value), ex=ex)

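The four new helpers are thin pass-throughs that apply the wrapper's key prefix before hitting Redis. A minimal usage sketch (not part of the commit; it assumes the module-level `redis` wrapper exported by llm_server.routes.cache and mirrors the payload that worker() pushes later in this commit):

    import json, time
    from llm_server.routes.cache import redis

    redis.rpush('generation_elapsed', json.dumps((time.time(), 4.5)))  # append to the prefixed list
    total = redis.llen('generation_elapsed')                           # length of the prefixed list
    oldest = redis.lindex('generation_elapsed', 0)                     # peek at the first entry
    redis.lrem('generation_elapsed', 1, oldest)                        # remove one matching entry
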
@@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
             msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
             backend_response = self.handle_error(msg)
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
-            return backend_response[0], 429  # We only return the response from handle_error(), not the error code
+            return backend_response[0], 200  # We only return the response from handle_error(), not the error code

     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'

@@ -9,9 +9,6 @@ from redis import Redis
 from llm_server import opts
 from llm_server.llm.generator import generator
 from llm_server.routes.cache import redis
-from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
-
-redis.set_dict('processing_ips', {})


 def increment_ip_count(client_ip: int, redis_key):

@@ -33,28 +30,21 @@ def decrement_ip_count(client_ip: int, redis_key):

 class RedisPriorityQueue:
     def __init__(self):
-        self._index = 0
         self._lock = threading.Lock()
         self.redis = Redis(host='localhost', port=6379, db=15)
-
-        # Clear the DB
-        for key in self.redis.scan_iter('*'):
-            self.redis.delete(key)
-
         self.pubsub = self.redis.pubsub()
         self.pubsub.subscribe('events')

     def put(self, item, priority):
         event = DataEvent()

         # Check if the IP is already in the dictionary and if it has reached the limit
         ip_count = self.redis.hget('queued_ip_count', item[1])
         if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
             print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
             return None  # reject the request
-        self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
-        self._index += 1
-        # Increment the count for this IP
-        with self._lock:
-            self.increment_ip_count(item[1], 'queued_ip_count')
-
+        self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
+        self.increment_ip_count(item[1], 'queued_ip_count')
         return event

     def get(self):

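For reference, a standalone sketch of the sorted-set pattern put() relies on (not part of the commit; the key name comes from the diff, the item contents are illustrative): members are stored with the negated priority as the score, so the entry with the largest priority value has the lowest score and is returned first by a lowest-score pop.

    import json
    from redis import Redis  # redis-py, as used above

    r = Redis(host='localhost', port=6379, db=15)
    r.zadd('queue', {json.dumps((('prompt body', '1.2.3.4', 'token', {}), 'event-id')): -5})  # priority 5
    data = r.zpopmin('queue')            # [(member, score)] for the lowest score, if any
    if data:
        item, event_id = json.loads(data[0][0])

The get() path above appears to consume the same [(member, score)] shape via data[0][0]; zpopmin is one way to produce it.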
@@ -64,8 +54,7 @@ class RedisPriorityQueue:
                 item = json.loads(data[0][0])
                 client_ip = item[1][1]
-                # Decrement the count for this IP
-                with self._lock:
-                    self.decrement_ip_count(client_ip, 'queued_ip_count')
+                self.decrement_ip_count(client_ip, 'queued_ip_count')
                 return item
             time.sleep(1)  # wait for an item to be added to the queue

@@ -100,11 +89,9 @@ priority_queue = RedisPriorityQueue()

 def worker():
     while True:
-        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
+        (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()

         increment_ip_count(client_ip, 'processing_ips')
-
-        # TODO: only increment if not valid SYSTEM__ token
         redis.incr('active_gen_workers')

         try:

@@ -113,15 +100,12 @@ def worker():
             end_time = time.time()

             elapsed_time = end_time - start_time
-            with generation_elapsed_lock:
-                generation_elapsed.append((end_time, elapsed_time))
+            redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))

             event = DataEvent(event_id)
             event.set((success, response, error_msg))
         finally:
             decrement_ip_count(client_ip, 'processing_ips')

-            # TODO: only decrement if not valid SYSTEM__ token
             redis.decr('active_gen_workers')

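With the per-process generation_elapsed list and its lock gone, timings are appended to a shared Redis list instead, so any process can read them. One possible consumer-side sketch (not in the commit) of recomputing a recent average from that list using the wrapper helpers added earlier:

    import json, time
    from llm_server.routes.cache import redis

    cutoff = time.time() - 60  # only look at the last minute
    entries = [json.loads(redis.lindex('generation_elapsed', i)) for i in range(redis.llen('generation_elapsed'))]
    recent = [elapsed for end, elapsed in entries if end >= cutoff]
    average = sum(recent) / len(recent) if recent else 0
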
@@ -193,6 +193,10 @@ class RequestHandler:
         return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))

     def is_client_ratelimited(self) -> bool:
+        print('queued_ip_count', redis.get_dict('queued_ip_count'))
+        print('processing_ips', redis.get_dict('processing_ips'))
+
+
         queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
         if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
             return False

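Debug prints aside, the gate works on the combined queued + processing count for the client's IP. An illustrative evaluation with made-up numbers (the field names come from the diff, the values are hypothetical):

    queued = {'1.2.3.4': 2}.get('1.2.3.4', 0)       # as from redis.get_dict('queued_ip_count')
    processing = {'1.2.3.4': 1}.get('1.2.3.4', 0)   # as from redis.get_dict('processing_ips')
    token_simultaneous_ip, token_priority = 3, 1     # hypothetical per-token settings
    ratelimited = not (queued + processing < token_simultaneous_ip or token_priority == 0)  # True here
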
@@ -1,6 +1,6 @@
 import time
 from datetime import datetime
-from threading import Lock, Thread
+from threading import Thread

 from llm_server.routes.cache import redis

@@ -9,14 +9,6 @@ from llm_server.routes.cache import redis

 server_start_time = datetime.now()

-# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
-
-# wait_in_queue_elapsed = []
-# wait_in_queue_elapsed_lock = Lock()
-
-generation_elapsed = []
-generation_elapsed_lock = Lock()
-

 # TODO: do I need this?
 # def elapsed_times_cleanup():

@@ -30,8 +22,6 @@ generation_elapsed_lock = Lock()


 def calculate_avg_gen_time():
-    # TODO: calculate the average from the database. Have this be set by an option in the config
-
     # Get the average generation time from Redis
     average_generation_time = redis.get('average_generation_time')
     if average_generation_time is None:

@@ -40,30 +30,6 @@ def calculate_avg_gen_time():
     return float(average_generation_time)


-def process_avg_gen_time():
-    global generation_elapsed
-    while True:
-        with generation_elapsed_lock:
-            # Get the current time
-            current_time = time.time()
-
-            # Remove data older than 3 minutes
-            three_minutes_ago = current_time - 180
-            generation_elapsed[:] = [(end, elapsed) for end, elapsed in generation_elapsed if end >= three_minutes_ago]
-
-            # Get the data from the last minute
-            one_minute_ago = current_time - 60
-            recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
-
-            # Calculate the average
-            if len(recent_data) == 0:
-                average_generation_time = 0
-            else:
-                average_generation_time = sum(recent_data) / len(recent_data)
-        redis.set('average_generation_time', average_generation_time)
-        time.sleep(5)
-
-
 def get_total_proompts():
     count = redis.get('proompts')
     if count is None:

@@ -8,7 +8,7 @@ from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
 from llm_server.routes.cache import redis
 from llm_server.routes.queue import priority_queue
-from llm_server.routes.stats import calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
+from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time


 def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):

@@ -61,22 +61,8 @@ def generate_stats(regen: bool = False):
     # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
     # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)

-    if opts.average_generation_time_mode == 'database':
-        average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
-
-        # What to use in our math that calculates the wait time.
-        # We could use the average TPS but we don't know the exact TPS value, only
-        # the backend knows that. So, let's just stick with the elapsed time.
-        gen_time_calc = average_generation_time
-
-        estimated_wait_sec = calculate_wait_time(gen_time_calc, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
-
-    elif opts.average_generation_time_mode == 'minute':
-        average_generation_time = calculate_avg_gen_time()
-        gen_time_calc = average_generation_time
-        estimated_wait_sec = ((gen_time_calc * proompters_in_queue) / opts.concurrent_gens) + (active_gen_workers * gen_time_calc)
-    else:
-        raise Exception
+    average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
+    estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)

     if opts.netdata_root:
         netdata_stats = {}

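Rough arithmetic for the kind of estimate this collapses to, with made-up numbers (the body of calculate_wait_time() is not shown in this diff; the formula below is the one from the removed 'minute' branch):

    average_generation_time = 10.0  # seconds per generation
    proompters_in_queue = 4
    concurrent_gens = 2
    active_gen_workers = 1
    estimated_wait_sec = (average_generation_time * proompters_in_queue) / concurrent_gens + active_gen_workers * average_generation_time  # -> 30.0
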
@@ -100,7 +86,7 @@ def generate_stats(regen: bool = False):
         },
         'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
         'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
-        'average_generation_elapsed_sec': int(gen_time_calc),
+        'average_generation_elapsed_sec': int(average_generation_time),
         # 'estimated_avg_tps': estimated_avg_tps,
         'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
         'nvidia': netdata_stats

server.py (15 lines changed)
@@ -1,3 +1,5 @@
+from redis import Redis
+
 try:
     import gevent.monkey

@@ -32,6 +34,8 @@ from llm_server.stream import init_socketio
 # TODO: allow setting concurrent gens per-backend
 # TODO: set the max tokens to that of the lowest backend
 # TODO: implement RRD backend loadbalancer option
+# TODO: have VLLM reject a request if it already has n == concurrent_gens running
+# TODO: add a way to cancel VLLM gens. Maybe use websockets?

 # Lower priority
 # TODO: the processing stat showed -1 and I had to restart the server

@@ -62,7 +66,7 @@ from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import RedisWrapper, flask_cache
 from llm_server.llm import redis
 from llm_server.routes.queue import start_workers
-from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
+from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers

@@ -160,6 +164,12 @@ def pre_fork(server):
     flushed_keys = redis.flush()
     print('Flushed', len(flushed_keys), 'keys from Redis.')

+    redis.set_dict('processing_ips', {})
+    redis.set_dict('queued_ip_count', {})
+    queue_redis = Redis(host='localhost', port=6379, db=15)
+    for key in queue_redis.scan_iter('*'):
+        queue_redis.delete(key)
+
     redis.set('backend_mode', opts.mode)
     if config['http_host']:
         http_host = re.sub(r'http(?:s)?://', '', config["http_host"])

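The queue-state reset that used to happen inside RedisPriorityQueue.__init__() (and the processing_ips reset at queue.py import time) now runs once here, before the worker processes start. As an aside, not part of the commit: assuming nothing else shares Redis db 15, flushing that database in one call would have the same effect as the SCAN-and-delete loop.

    queue_redis = Redis(host='localhost', port=6379, db=15)
    queue_redis.flushdb()  # clears the whole queue database at once
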
@@ -174,9 +184,6 @@ def pre_fork(server):
     print(f'Started {opts.concurrent_gens} inference workers.')

     start_moderation_workers(opts.openai_moderation_workers)
-    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
-    process_avg_gen_time_background_thread.daemon = True
-    process_avg_gen_time_background_thread.start()
     MainBackgroundThread().start()
     SemaphoreCheckerThread().start()