clean up background threads

parent 35e9847b27
commit 43299b32ad
@@ -31,7 +31,8 @@ config_default_vars = {
     'openai_moderation_workers': 10,
     'openai_org_name': 'OpenAI',
     'openai_silent_trim': False,
-    'openai_moderation_enabled': True
+    'openai_moderation_enabled': True,
+    'netdata_root': None
 }
 
 config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']
@@ -6,7 +6,7 @@ import redis as redis_pkg
 import simplejson as json
 from flask_caching import Cache
 from redis import Redis
-from redis.typing import FieldT, ExpiryT
+from redis.typing import ExpiryT, FieldT
 
 flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
 
@@ -72,6 +72,20 @@ class RedisWrapper:
     def sismember(self, key: str, value: str):
         return self.redis.sismember(self._key(key), value)
 
+    def lindex(
+            self, name: str, index: int
+    ):
+        return self.redis.lindex(self._key(name), index)
+
+    def lrem(self, name: str, count: int, value: str):
+        return self.redis.lrem(self._key(name), count, value)
+
+    def rpush(self, name: str, *values: FieldT):
+        return self.redis.rpush(self._key(name), *values)
+
+    def llen(self, name: str):
+        return self.redis.llen(self._key(name))
+
     def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
         return self.set(key, json.dumps(dict_value), ex=ex)
 
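For reference, a minimal usage sketch of the new list helpers (not part of the commit; the key and values are made up, and it assumes the module-level RedisWrapper instance `redis` exported by llm_server.routes.cache, with `_key()` applying the wrapper's key prefix):

from llm_server.routes.cache import redis  # shared RedisWrapper instance

# Each helper prefixes the key via _key() before hitting Redis.
redis.rpush('generation_elapsed', '{"end": 1.0, "elapsed": 2.5}')    # append to the prefixed list
print(redis.llen('generation_elapsed'))                              # -> 1
print(redis.lindex('generation_elapsed', 0))                         # -> b'{"end": 1.0, "elapsed": 2.5}'
redis.lrem('generation_elapsed', 1, '{"end": 1.0, "elapsed": 2.5}')  # drop the first matching entry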
@@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
             msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
             backend_response = self.handle_error(msg)
             log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
-            return backend_response[0], 429  # We only return the response from handle_error(), not the error code
+            return backend_response[0], 200  # We only return the response from handle_error(), not the error code
 
     def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
         disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
@@ -9,9 +9,6 @@ from redis import Redis
 from llm_server import opts
 from llm_server.llm.generator import generator
 from llm_server.routes.cache import redis
-from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
-
-redis.set_dict('processing_ips', {})
 
 
 def increment_ip_count(client_ip: int, redis_key):
@@ -33,27 +30,20 @@ def decrement_ip_count(client_ip: int, redis_key):
 
 class RedisPriorityQueue:
     def __init__(self):
-        self._index = 0
-        self._lock = threading.Lock()
         self.redis = Redis(host='localhost', port=6379, db=15)
 
-        # Clear the DB
-        for key in self.redis.scan_iter('*'):
-            self.redis.delete(key)
 
         self.pubsub = self.redis.pubsub()
         self.pubsub.subscribe('events')
 
     def put(self, item, priority):
         event = DataEvent()
 
         # Check if the IP is already in the dictionary and if it has reached the limit
         ip_count = self.redis.hget('queued_ip_count', item[1])
         if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
+            print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
             return None  # reject the request
-        self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
-        self._index += 1
-        # Increment the count for this IP
-        with self._lock:
-            self.increment_ip_count(item[1], 'queued_ip_count')
+        self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
+        self.increment_ip_count(item[1], 'queued_ip_count')
         return event
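The queue is backed by a Redis sorted set: put() stores each member with score = -priority, so higher-priority items get lower scores and pop first. A standalone sketch of that pattern (not from the commit; the DB number matches the code above, the member values are made up, and it assumes get() uses a min-pop such as ZPOPMIN):

import json
from redis import Redis

r = Redis(host='localhost', port=6379, db=15)  # same dedicated queue DB as above

# Two queued requests; each member mirrors the (item, event_id) tuple used by put().
r.zadd('queue', {json.dumps((({'prompt': 'hi'}, '1.2.3.4', None, {}), 'event-aaa')): -5})  # priority 5
r.zadd('queue', {json.dumps((({'prompt': 'yo'}, '5.6.7.8', None, {}), 'event-bbb')): -1})  # priority 1

member, score = r.zpopmin('queue')[0]  # lowest score first, i.e. highest priority
print(json.loads(member), score)       # the priority-5 request comes out first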
@@ -64,7 +54,6 @@ class RedisPriorityQueue:
                 item = json.loads(data[0][0])
                 client_ip = item[1][1]
                 # Decrement the count for this IP
-                with self._lock:
-                    self.decrement_ip_count(client_ip, 'queued_ip_count')
+                self.decrement_ip_count(client_ip, 'queued_ip_count')
                 return item
             time.sleep(1)  # wait for an item to be added to the queue
@@ -100,11 +89,9 @@ priority_queue = RedisPriorityQueue()
 
 def worker():
     while True:
-        index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
+        (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
 
         increment_ip_count(client_ip, 'processing_ips')
-
-        # TODO: only increment if not valid SYSTEM__ token
         redis.incr('active_gen_workers')
 
         try:
@@ -113,15 +100,12 @@ def worker():
             end_time = time.time()
 
             elapsed_time = end_time - start_time
-            with generation_elapsed_lock:
-                generation_elapsed.append((end_time, elapsed_time))
+            redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
 
             event = DataEvent(event_id)
             event.set((success, response, error_msg))
         finally:
             decrement_ip_count(client_ip, 'processing_ips')
-
-            # TODO: only decrement if not valid SYSTEM__ token
             redis.decr('active_gen_workers')
 
 
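With the in-process generation_elapsed list and its lock gone, timing samples now live in the shared 'generation_elapsed' Redis list written above, so any process can read them. A hedged sketch of how a consumer elsewhere (not shown in this diff) might compute a recent average from that list via the new RedisWrapper helpers; the function name and window are illustrative:

import json
import time

from llm_server.routes.cache import redis  # shared RedisWrapper instance


def average_recent_generation_time(window_sec: int = 60) -> float:
    """Average the elapsed times recorded within the last window_sec seconds."""
    cutoff = time.time() - window_sec
    samples = []
    for i in range(redis.llen('generation_elapsed')):
        entry = redis.lindex('generation_elapsed', i)
        if entry is None:
            break  # list shrank underneath us; good enough for a sketch
        end_time, elapsed = json.loads(entry)
        if end_time >= cutoff:
            samples.append(elapsed)
    return sum(samples) / len(samples) if samples else 0.0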
@@ -193,6 +193,10 @@ class RequestHandler:
         return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
 
     def is_client_ratelimited(self) -> bool:
+        print('queued_ip_count', redis.get_dict('queued_ip_count'))
+        print('processing_ips', redis.get_dict('processing_ips'))
+
+
         queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
         if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
             return False
@@ -1,6 +1,6 @@
 import time
 from datetime import datetime
-from threading import Lock, Thread
+from threading import Thread
 
 from llm_server.routes.cache import redis
 
@@ -9,14 +9,6 @@ from llm_server.routes.cache import redis
 
 server_start_time = datetime.now()
 
-# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
-
-# wait_in_queue_elapsed = []
-# wait_in_queue_elapsed_lock = Lock()
-
-generation_elapsed = []
-generation_elapsed_lock = Lock()
-
 
 # TODO: do I need this?
 # def elapsed_times_cleanup():
@@ -30,8 +22,6 @@ generation_elapsed_lock = Lock()
 
 
 def calculate_avg_gen_time():
-    # TODO: calculate the average from the database. Have this be set by an option in the config
-
     # Get the average generation time from Redis
     average_generation_time = redis.get('average_generation_time')
     if average_generation_time is None:
|
||||||
return float(average_generation_time)
|
return float(average_generation_time)
|
||||||
|
|
||||||
|
|
||||||
def process_avg_gen_time():
|
|
||||||
global generation_elapsed
|
|
||||||
while True:
|
|
||||||
with generation_elapsed_lock:
|
|
||||||
# Get the current time
|
|
||||||
current_time = time.time()
|
|
||||||
|
|
||||||
# Remove data older than 3 minutes
|
|
||||||
three_minutes_ago = current_time - 180
|
|
||||||
generation_elapsed[:] = [(end, elapsed) for end, elapsed in generation_elapsed if end >= three_minutes_ago]
|
|
||||||
|
|
||||||
# Get the data from the last minute
|
|
||||||
one_minute_ago = current_time - 60
|
|
||||||
recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
|
|
||||||
|
|
||||||
# Calculate the average
|
|
||||||
if len(recent_data) == 0:
|
|
||||||
average_generation_time = 0
|
|
||||||
else:
|
|
||||||
average_generation_time = sum(recent_data) / len(recent_data)
|
|
||||||
redis.set('average_generation_time', average_generation_time)
|
|
||||||
time.sleep(5)
|
|
||||||
|
|
||||||
|
|
||||||
def get_total_proompts():
|
def get_total_proompts():
|
||||||
count = redis.get('proompts')
|
count = redis.get('proompts')
|
||||||
if count is None:
|
if count is None:
|
||||||
|
|
|
@@ -8,7 +8,7 @@ from llm_server.llm.info import get_running_model
 from llm_server.netdata import get_power_states
 from llm_server.routes.cache import redis
 from llm_server.routes.queue import priority_queue
-from llm_server.routes.stats import calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time
+from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
 
 
 def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
@@ -61,22 +61,8 @@ def generate_stats(regen: bool = False):
     # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
     # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
 
-    if opts.average_generation_time_mode == 'database':
-        average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
-
-        # What to use in our math that calculates the wait time.
-        # We could use the average TPS but we don't know the exact TPS value, only
-        # the backend knows that. So, let's just stick with the elapsed time.
-        gen_time_calc = average_generation_time
-
-        estimated_wait_sec = calculate_wait_time(gen_time_calc, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
-
-    elif opts.average_generation_time_mode == 'minute':
-        average_generation_time = calculate_avg_gen_time()
-        gen_time_calc = average_generation_time
-        estimated_wait_sec = ((gen_time_calc * proompters_in_queue) / opts.concurrent_gens) + (active_gen_workers * gen_time_calc)
-    else:
-        raise Exception
+    average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
+    estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
 
     if opts.netdata_root:
         netdata_stats = {}
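For intuition, a worked example of the wait-time estimate. calculate_wait_time() is not shown in this diff; this assumes it follows the same formula the removed 'minute' branch used, with hypothetical numbers:

average_generation_time = 10.0  # seconds, hypothetical value of 'average_generation_elapsed_sec'
proompters_in_queue = 6
concurrent_gens = 3
active_gen_workers = 2

estimated_wait_sec = (average_generation_time * proompters_in_queue) / concurrent_gens \
    + active_gen_workers * average_generation_time
print(estimated_wait_sec)  # (10 * 6) / 3 + 2 * 10 = 40.0 seconds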
@@ -100,7 +86,7 @@ def generate_stats(regen: bool = False):
         },
         'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
         'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
-        'average_generation_elapsed_sec': int(gen_time_calc),
+        'average_generation_elapsed_sec': int(average_generation_time),
         # 'estimated_avg_tps': estimated_avg_tps,
         'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
         'nvidia': netdata_stats
server.py: 15 changed lines
@@ -1,3 +1,5 @@
+from redis import Redis
+
 try:
     import gevent.monkey
 
@@ -32,6 +34,8 @@ from llm_server.stream import init_socketio
 # TODO: allow setting concurrent gens per-backend
 # TODO: set the max tokens to that of the lowest backend
 # TODO: implement RRD backend loadbalancer option
+# TODO: have VLLM reject a request if it already has n == concurrent_gens running
+# TODO: add a way to cancel VLLM gens. Maybe use websockets?
 
 # Lower priority
 # TODO: the processing stat showed -1 and I had to restart the server
@@ -62,7 +66,7 @@ from llm_server.llm.vllm.info import vllm_info
 from llm_server.routes.cache import RedisWrapper, flask_cache
 from llm_server.llm import redis
 from llm_server.routes.queue import start_workers
-from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time
+from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
 from llm_server.routes.v1.generate_stats import generate_stats
 from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
 
@@ -160,6 +164,12 @@ def pre_fork(server):
     flushed_keys = redis.flush()
     print('Flushed', len(flushed_keys), 'keys from Redis.')
 
+    redis.set_dict('processing_ips', {})
+    redis.set_dict('queued_ip_count', {})
+    queue_redis = Redis(host='localhost', port=6379, db=15)
+    for key in queue_redis.scan_iter('*'):
+        queue_redis.delete(key)
+
     redis.set('backend_mode', opts.mode)
     if config['http_host']:
         http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
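Design note: the queue gets its own Redis database (db=15), so clearing it here cannot touch the db 0 keys used by the cache and stats. A hedged sketch of a one-call alternative to the SCAN/DELETE loop above (not what the commit does, just an equivalent):

from redis import Redis

queue_redis = Redis(host='localhost', port=6379, db=15)
queue_redis.flushdb()  # wipes only db 15; db 0 (flask cache, counters) is untouched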
@@ -174,9 +184,6 @@ def pre_fork(server):
     print(f'Started {opts.concurrent_gens} inference workers.')
 
     start_moderation_workers(opts.openai_moderation_workers)
-    process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
-    process_avg_gen_time_background_thread.daemon = True
-    process_avg_gen_time_background_thread.start()
     MainBackgroundThread().start()
     SemaphoreCheckerThread().start()
 