clean up background threads

This commit is contained in:
Cyberes 2023-09-27 19:39:04 -06:00
parent 35e9847b27
commit 43299b32ad
8 changed files with 46 additions and 84 deletions

View File

@ -31,7 +31,8 @@ config_default_vars = {
'openai_moderation_workers': 10, 'openai_moderation_workers': 10,
'openai_org_name': 'OpenAI', 'openai_org_name': 'OpenAI',
'openai_silent_trim': False, 'openai_silent_trim': False,
'openai_moderation_enabled': True 'openai_moderation_enabled': True,
'netdata_root': None
} }
config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name'] config_required_vars = ['token_limit', 'concurrent_gens', 'mode', 'llm_middleware_name']

View File

@ -6,7 +6,7 @@ import redis as redis_pkg
import simplejson as json import simplejson as json
from flask_caching import Cache from flask_caching import Cache
from redis import Redis from redis import Redis
from redis.typing import FieldT, ExpiryT from redis.typing import ExpiryT, FieldT
flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'}) flask_cache = Cache(config={'CACHE_TYPE': 'RedisCache', 'CACHE_REDIS_URL': 'redis://localhost:6379/0', 'CACHE_KEY_PREFIX': 'local_llm_flask'})
@ -72,6 +72,20 @@ class RedisWrapper:
def sismember(self, key: str, value: str): def sismember(self, key: str, value: str):
return self.redis.sismember(self._key(key), value) return self.redis.sismember(self._key(key), value)
def lindex(
self, name: str, index: int
):
return self.redis.lindex(self._key(name), index)
def lrem(self, name: str, count: int, value: str):
return self.redis.lrem(self._key(name), count, value)
def rpush(self, name: str, *values: FieldT):
return self.redis.rpush(self._key(name), *values)
def llen(self, name: str):
return self.redis.llen(self._key(name))
def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None): def set_dict(self, key: Union[list, dict], dict_value, ex: Union[ExpiryT, None] = None):
return self.set(key, json.dumps(dict_value), ex=ex) return self.set(key, json.dumps(dict_value), ex=ex)

View File

@ -31,7 +31,7 @@ class OobaRequestHandler(RequestHandler):
msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.' msg = f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
backend_response = self.handle_error(msg) backend_response = self.handle_error(msg)
log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True) log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, is_error=True)
return backend_response[0], 429 # We only return the response from handle_error(), not the error code return backend_response[0], 200 # We only return the response from handle_error(), not the error code
def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]: def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true' disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'

View File

@ -9,9 +9,6 @@ from redis import Redis
from llm_server import opts from llm_server import opts
from llm_server.llm.generator import generator from llm_server.llm.generator import generator
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
from llm_server.routes.stats import generation_elapsed, generation_elapsed_lock
redis.set_dict('processing_ips', {})
def increment_ip_count(client_ip: int, redis_key): def increment_ip_count(client_ip: int, redis_key):
@ -33,27 +30,20 @@ def decrement_ip_count(client_ip: int, redis_key):
class RedisPriorityQueue: class RedisPriorityQueue:
def __init__(self): def __init__(self):
self._index = 0
self._lock = threading.Lock()
self.redis = Redis(host='localhost', port=6379, db=15) self.redis = Redis(host='localhost', port=6379, db=15)
# Clear the DB
for key in self.redis.scan_iter('*'):
self.redis.delete(key)
self.pubsub = self.redis.pubsub() self.pubsub = self.redis.pubsub()
self.pubsub.subscribe('events') self.pubsub.subscribe('events')
def put(self, item, priority): def put(self, item, priority):
event = DataEvent() event = DataEvent()
# Check if the IP is already in the dictionary and if it has reached the limit # Check if the IP is already in the dictionary and if it has reached the limit
ip_count = self.redis.hget('queued_ip_count', item[1]) ip_count = self.redis.hget('queued_ip_count', item[1])
if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0: if ip_count and int(ip_count) >= opts.simultaneous_requests_per_ip and priority != 0:
print(f'Rejecting request from {item[1]} - {ip_count} requests in progress.')
return None # reject the request return None # reject the request
self.redis.zadd('queue', {json.dumps((self._index, item, event.event_id)): -priority})
self._index += 1 self.redis.zadd('queue', {json.dumps((item, event.event_id)): -priority})
# Increment the count for this IP
with self._lock:
self.increment_ip_count(item[1], 'queued_ip_count') self.increment_ip_count(item[1], 'queued_ip_count')
return event return event
@ -64,7 +54,6 @@ class RedisPriorityQueue:
item = json.loads(data[0][0]) item = json.loads(data[0][0])
client_ip = item[1][1] client_ip = item[1][1]
# Decrement the count for this IP # Decrement the count for this IP
with self._lock:
self.decrement_ip_count(client_ip, 'queued_ip_count') self.decrement_ip_count(client_ip, 'queued_ip_count')
return item return item
time.sleep(1) # wait for an item to be added to the queue time.sleep(1) # wait for an item to be added to the queue
@ -100,11 +89,9 @@ priority_queue = RedisPriorityQueue()
def worker(): def worker():
while True: while True:
index, (request_json_body, client_ip, token, parameters), event_id = priority_queue.get() (request_json_body, client_ip, token, parameters), event_id = priority_queue.get()
increment_ip_count(client_ip, 'processing_ips') increment_ip_count(client_ip, 'processing_ips')
# TODO: only increment if not valid SYSTEM__ token
redis.incr('active_gen_workers') redis.incr('active_gen_workers')
try: try:
@ -113,15 +100,12 @@ def worker():
end_time = time.time() end_time = time.time()
elapsed_time = end_time - start_time elapsed_time = end_time - start_time
with generation_elapsed_lock: redis.rpush('generation_elapsed', json.dumps((end_time, elapsed_time)))
generation_elapsed.append((end_time, elapsed_time))
event = DataEvent(event_id) event = DataEvent(event_id)
event.set((success, response, error_msg)) event.set((success, response, error_msg))
finally: finally:
decrement_ip_count(client_ip, 'processing_ips') decrement_ip_count(client_ip, 'processing_ips')
# TODO: only decrement if not valid SYSTEM__ token
redis.decr('active_gen_workers') redis.decr('active_gen_workers')

View File

@ -193,6 +193,10 @@ class RequestHandler:
return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers)) return (success, response, error_msg, elapsed_time), self.backend.handle_response(success, self.request, response_json_body, response_status_code, self.client_ip, self.token, prompt, elapsed_time, self.parameters, dict(self.request.headers))
def is_client_ratelimited(self) -> bool: def is_client_ratelimited(self) -> bool:
print('queued_ip_count', redis.get_dict('queued_ip_count'))
print('processing_ips', redis.get_dict('processing_ips'))
queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0) queued_ip_count = redis.get_dict('queued_ip_count').get(self.client_ip, 0) + redis.get_dict('processing_ips').get(self.client_ip, 0)
if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0: if queued_ip_count < self.token_simultaneous_ip or self.token_priority == 0:
return False return False

View File

@ -1,6 +1,6 @@
import time import time
from datetime import datetime from datetime import datetime
from threading import Lock, Thread from threading import Thread
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
@ -9,14 +9,6 @@ from llm_server.routes.cache import redis
server_start_time = datetime.now() server_start_time = datetime.now()
# TODO: have a background thread put the averages in a variable so we don't end up with massive arrays
# wait_in_queue_elapsed = []
# wait_in_queue_elapsed_lock = Lock()
generation_elapsed = []
generation_elapsed_lock = Lock()
# TODO: do I need this? # TODO: do I need this?
# def elapsed_times_cleanup(): # def elapsed_times_cleanup():
@ -30,8 +22,6 @@ generation_elapsed_lock = Lock()
def calculate_avg_gen_time(): def calculate_avg_gen_time():
# TODO: calculate the average from the database. Have this be set by an option in the config
# Get the average generation time from Redis # Get the average generation time from Redis
average_generation_time = redis.get('average_generation_time') average_generation_time = redis.get('average_generation_time')
if average_generation_time is None: if average_generation_time is None:
@ -40,30 +30,6 @@ def calculate_avg_gen_time():
return float(average_generation_time) return float(average_generation_time)
def process_avg_gen_time():
global generation_elapsed
while True:
with generation_elapsed_lock:
# Get the current time
current_time = time.time()
# Remove data older than 3 minutes
three_minutes_ago = current_time - 180
generation_elapsed[:] = [(end, elapsed) for end, elapsed in generation_elapsed if end >= three_minutes_ago]
# Get the data from the last minute
one_minute_ago = current_time - 60
recent_data = [elapsed for end, elapsed in generation_elapsed if end >= one_minute_ago]
# Calculate the average
if len(recent_data) == 0:
average_generation_time = 0
else:
average_generation_time = sum(recent_data) / len(recent_data)
redis.set('average_generation_time', average_generation_time)
time.sleep(5)
def get_total_proompts(): def get_total_proompts():
count = redis.get('proompts') count = redis.get('proompts')
if count is None: if count is None:

View File

@ -8,7 +8,7 @@ from llm_server.llm.info import get_running_model
from llm_server.netdata import get_power_states from llm_server.netdata import get_power_states
from llm_server.routes.cache import redis from llm_server.routes.cache import redis
from llm_server.routes.queue import priority_queue from llm_server.routes.queue import priority_queue
from llm_server.routes.stats import calculate_avg_gen_time, get_active_gen_workers, get_total_proompts, server_start_time from llm_server.routes.stats import get_active_gen_workers, get_total_proompts, server_start_time
def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers): def calculate_wait_time(gen_time_calc, proompters_in_queue, concurrent_gens, active_gen_workers):
@ -61,22 +61,8 @@ def generate_stats(regen: bool = False):
# This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM. # This is so wildly inaccurate it's disabled until I implement stats reporting into VLLM.
# estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0) # estimated_avg_tps = redis.get('estimated_avg_tps', float, default=0)
if opts.average_generation_time_mode == 'database':
average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0) average_generation_time = redis.get('average_generation_elapsed_sec', float, default=0)
estimated_wait_sec = calculate_wait_time(average_generation_time, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
# What to use in our math that calculates the wait time.
# We could use the average TPS but we don't know the exact TPS value, only
# the backend knows that. So, let's just stick with the elapsed time.
gen_time_calc = average_generation_time
estimated_wait_sec = calculate_wait_time(gen_time_calc, proompters_in_queue, opts.concurrent_gens, active_gen_workers)
elif opts.average_generation_time_mode == 'minute':
average_generation_time = calculate_avg_gen_time()
gen_time_calc = average_generation_time
estimated_wait_sec = ((gen_time_calc * proompters_in_queue) / opts.concurrent_gens) + (active_gen_workers * gen_time_calc)
else:
raise Exception
if opts.netdata_root: if opts.netdata_root:
netdata_stats = {} netdata_stats = {}
@ -100,7 +86,7 @@ def generate_stats(regen: bool = False):
}, },
'proompts_total': get_total_proompts() if opts.show_num_prompts else None, 'proompts_total': get_total_proompts() if opts.show_num_prompts else None,
'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None, 'uptime': int((datetime.now() - server_start_time).total_seconds()) if opts.show_uptime else None,
'average_generation_elapsed_sec': int(gen_time_calc), 'average_generation_elapsed_sec': int(average_generation_time),
# 'estimated_avg_tps': estimated_avg_tps, # 'estimated_avg_tps': estimated_avg_tps,
'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None, 'tokens_generated': sum_column('prompts', 'response_tokens') if opts.show_total_output_tokens else None,
'nvidia': netdata_stats 'nvidia': netdata_stats

View File

@ -1,3 +1,5 @@
from redis import Redis
try: try:
import gevent.monkey import gevent.monkey
@ -32,6 +34,8 @@ from llm_server.stream import init_socketio
# TODO: allow setting concurrent gens per-backend # TODO: allow setting concurrent gens per-backend
# TODO: set the max tokens to that of the lowest backend # TODO: set the max tokens to that of the lowest backend
# TODO: implement RRD backend loadbalancer option # TODO: implement RRD backend loadbalancer option
# TODO: have VLLM reject a request if it already has n == concurrent_gens running
# TODO: add a way to cancel VLLM gens. Maybe use websockets?
# Lower priority # Lower priority
# TODO: the processing stat showed -1 and I had to restart the server # TODO: the processing stat showed -1 and I had to restart the server
@ -62,7 +66,7 @@ from llm_server.llm.vllm.info import vllm_info
from llm_server.routes.cache import RedisWrapper, flask_cache from llm_server.routes.cache import RedisWrapper, flask_cache
from llm_server.llm import redis from llm_server.llm import redis
from llm_server.routes.queue import start_workers from llm_server.routes.queue import start_workers
from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers, process_avg_gen_time from llm_server.routes.stats import SemaphoreCheckerThread, get_active_gen_workers
from llm_server.routes.v1.generate_stats import generate_stats from llm_server.routes.v1.generate_stats import generate_stats
from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers from llm_server.threads import MainBackgroundThread, cache_stats, start_moderation_workers
@ -160,6 +164,12 @@ def pre_fork(server):
flushed_keys = redis.flush() flushed_keys = redis.flush()
print('Flushed', len(flushed_keys), 'keys from Redis.') print('Flushed', len(flushed_keys), 'keys from Redis.')
redis.set_dict('processing_ips', {})
redis.set_dict('queued_ip_count', {})
queue_redis = Redis(host='localhost', port=6379, db=15)
for key in queue_redis.scan_iter('*'):
queue_redis.delete(key)
redis.set('backend_mode', opts.mode) redis.set('backend_mode', opts.mode)
if config['http_host']: if config['http_host']:
http_host = re.sub(r'http(?:s)?://', '', config["http_host"]) http_host = re.sub(r'http(?:s)?://', '', config["http_host"])
@ -174,9 +184,6 @@ def pre_fork(server):
print(f'Started {opts.concurrent_gens} inference workers.') print(f'Started {opts.concurrent_gens} inference workers.')
start_moderation_workers(opts.openai_moderation_workers) start_moderation_workers(opts.openai_moderation_workers)
process_avg_gen_time_background_thread = Thread(target=process_avg_gen_time)
process_avg_gen_time_background_thread.daemon = True
process_avg_gen_time_background_thread.start()
MainBackgroundThread().start() MainBackgroundThread().start()
SemaphoreCheckerThread().start() SemaphoreCheckerThread().start()