121 lines
4.7 KiB
Python
121 lines
4.7 KiB
Python
import json
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from threading import Thread
|
|
|
|
import redis as redis_redis
|
|
|
|
from llm_server import opts
|
|
from llm_server.database.database import weighted_average_column_for_model
|
|
from llm_server.llm.info import get_running_model
|
|
from llm_server.llm.openai.moderation import check_moderation_endpoint
|
|
from llm_server.routes.cache import redis
|
|
from llm_server.routes.v1.generate_stats import generate_stats
|
|
|
|
|
|
class MainBackgroundThread(Thread):
    """Daemon thread that periodically refreshes backend status and stats in Redis.

    Every 60 seconds the loop calls ``get_running_model()``, records in Redis
    whether that call reported an error, and recomputes the weighted averages
    for generation time and output tokens plus the derived tokens-per-second
    estimate.
    """

    backend_online = False

    # Backend modes that run() knows how to poll.
    _SUPPORTED_MODES = ('oobabooga', 'vllm')

    # Seconds to wait between two refresh passes.
    _POLL_INTERVAL_SEC = 60

    # TODO: do I really need to put everything in Redis?
    # TODO: call generate_stats() every minute, cache the results, put results in a DB table, then have other parts of code call this cache

    def __init__(self):
        super().__init__()
        self.daemon = True
        # Seed every key this thread maintains so the keys exist before the
        # first pass of run() has completed.
        redis.set('average_generation_elapsed_sec', 0)
        redis.set('estimated_avg_tps', 0)
        redis.set('average_output_tokens', 0)
        redis.set('backend_online', 0)
        redis.set_dict('backend_info', {})

    @staticmethod
    def _estimated_tps(average_output_tokens, average_generation_elapsed_sec):
        """Return tokens per second rounded to 2 places, or 0 when the elapsed time is not positive."""
        if average_generation_elapsed_sec > 0:
            return round(average_output_tokens / average_generation_elapsed_sec, 2)
        return 0  # Avoid division by zero

    def run(self):
        """Poll the backend and update the cached statistics forever.

        Raises:
            Exception: if ``opts.mode`` is not one of the supported modes.
        """
        while True:
            # Both supported modes are polled the same way, so a single code
            # path replaces the two previously duplicated branches.
            if opts.mode not in self._SUPPORTED_MODES:
                raise Exception(f'Unsupported backend mode: {opts.mode!r}')

            running_model, err = get_running_model()
            if err:
                print(err)
                redis.set('backend_online', 0)
            else:
                redis.set('running_model', running_model)
                redis.set('backend_online', 1)
            # NOTE(review): when err is set, running_model is still passed to
            # the queries below, exactly as before -- confirm that is intended.

            # exclude_zeros=True filters out rows where an error message was returned. Previously, if there was an error, 0
            # was entered into the column. The new code enters null instead but we need to be backwards compatible for now.
            average_generation_elapsed_sec = weighted_average_column_for_model(
                'prompts',
                'generation_time',
                running_model,
                opts.mode,
                opts.backend_url,
                exclude_zeros=True,
                include_system_tokens=opts.include_system_tokens_in_stats,
            ) or 0  # a falsy result (the original comment says None on exception) becomes 0
            if average_generation_elapsed_sec:
                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

            average_output_tokens = weighted_average_column_for_model(
                'prompts',
                'response_tokens',
                running_model,
                opts.mode,
                opts.backend_url,
                exclude_zeros=True,
                include_system_tokens=opts.include_system_tokens_in_stats,
            ) or 0
            # Guard on the value being stored; the previous code tested
            # average_generation_elapsed_sec here by mistake.
            if average_output_tokens:
                redis.set('average_output_tokens', average_output_tokens)

            estimated_avg_tps = self._estimated_tps(average_output_tokens, average_generation_elapsed_sec)
            redis.set('estimated_avg_tps', estimated_avg_tps)
            time.sleep(self._POLL_INTERVAL_SEC)
|
|
|
|
|
|
def cache_stats():
    """Regenerate the stats via ``generate_stats(regen=True)`` in an endless loop.

    Never returns; intended to be run in its own thread or process.
    """
    refresh_interval_sec = 5
    while True:
        generate_stats(regen=True)
        time.sleep(refresh_interval_sec)
|
|
|
|
|
|
# Raw redis-py client used by the moderation queue functions in this module.
# It is constructed with no arguments, so it relies on the library's default
# connection settings.
redis_moderation = redis_redis.Redis()
|
|
|
|
|
|
def start_moderation_workers(num_workers):
    """Start ``num_workers`` daemon threads, each running ``moderation_worker``.

    Nothing is started when ``num_workers`` is zero or negative.
    """
    started = 0
    while started < num_workers:
        threading.Thread(target=moderation_worker, daemon=True).start()
        started += 1
|
|
|
|
|
|
def moderation_worker():
    """Consume moderation jobs from Redis forever and publish their results.

    Each job is a JSON-encoded ``(msg, tag)`` pair popped from
    ``queue:msgs_to_check``. The message is passed to
    ``check_moderation_endpoint`` and the resulting categories are pushed,
    together with the tag, onto ``queue:flagged_categories``.

    A failure while handling one job is printed and the worker moves on to the
    next job.
    """
    while True:
        # Blocks until a job is available.
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except Exception:
            # Catch Exception rather than using a bare except so that
            # SystemExit and KeyboardInterrupt are not swallowed here.
            print(result)
            traceback.print_exc()
|
|
|
|
|
|
def add_moderation_task(msg, tag):
    """Queue ``msg`` for moderation, labelled with ``tag`` converted to a string."""
    payload = json.dumps((msg, str(tag)))
    redis_moderation.rpush('queue:msgs_to_check', payload)
|
|
|
|
|
|
def get_results(tag, num_tasks):
    """Collect the flagged categories of ``num_tasks`` results labelled ``tag``.

    Blocks on ``queue:flagged_categories`` until ``num_tasks`` entries with a
    matching tag have been read, then returns the distinct categories as a
    list. Returns an empty list immediately when ``num_tasks`` is not positive.

    NOTE(review): entries popped with a different tag are dropped, not put
    back on the queue -- confirm this cannot starve a concurrent caller that
    is waiting for them.
    """
    wanted_tag = str(tag)  # Tags come back from Redis as strings, so compare as strings.
    collected = set()
    matched = 0
    while matched < num_tasks:
        raw = redis_moderation.blpop('queue:flagged_categories')[1]
        result_tag, categories = json.loads(raw)
        if result_tag != wanted_tag:
            continue
        if categories:
            collected.update(categories)
        matched += 1
    return list(collected)
|