import json
import threading
import time
import traceback
from threading import Thread

import redis as redis_redis

from llm_server import opts
from llm_server.database.database import weighted_average_column_for_model
from llm_server.llm.info import get_running_model
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.routes.cache import redis
from llm_server.routes.v1.generate_stats import generate_stats


class MainBackgroundThread(Thread):
    """Daemon thread that polls the backend every 60s and refreshes cached
    model/throughput stats in Redis."""
    backend_online = False

    # TODO: do I really need to put everything in Redis?
    # TODO: call generate_stats() every minute, cache the results, put results in
    # a DB table, then have other parts of code call this cache

    def __init__(self):
        Thread.__init__(self)
        self.daemon = True
        # Seed every key so readers never observe a missing key before the first poll.
        redis.set('average_generation_elapsed_sec', 0)
        redis.set('estimated_avg_tps', 0)
        redis.set('average_output_tokens', 0)
        redis.set('backend_online', 0)
        redis.set_dict('backend_info', {})

    def run(self):
        while True:
            # Both supported modes poll identically, so the previously duplicated
            # per-mode branches are unified here.
            if opts.mode in ('oobabooga', 'vllm'):
                running_model, err = get_running_model()
                if err:
                    print(err)
                    redis.set('backend_online', 0)
                else:
                    redis.set('running_model', running_model)
                    redis.set('backend_online', 1)
            else:
                raise Exception(f'Unknown mode: {opts.mode}')

            # exclude_zeros=True filters out rows where an error message was returned.
            # Previously, if there was an error, 0 was entered into the column. The new
            # code enters null instead but we need to be backwards compatible for now.
            average_generation_elapsed_sec = weighted_average_column_for_model(
                'prompts', 'generation_time', running_model, opts.mode, opts.backend_url,
                exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            if average_generation_elapsed_sec:  # returns None on exception
                redis.set('average_generation_elapsed_sec', average_generation_elapsed_sec)

            average_output_tokens = weighted_average_column_for_model(
                'prompts', 'response_tokens', running_model, opts.mode, opts.backend_url,
                exclude_zeros=True, include_system_tokens=opts.include_system_tokens_in_stats) or 0
            # Bug fix: the guard previously re-tested average_generation_elapsed_sec
            # (copy-paste), so a stale average_output_tokens could be written.
            if average_output_tokens:
                redis.set('average_output_tokens', average_output_tokens)

            # Avoid division by zero when no timing data has been collected yet.
            if average_generation_elapsed_sec > 0:
                estimated_avg_tps = round(average_output_tokens / average_generation_elapsed_sec, 2)
            else:
                estimated_avg_tps = 0
            redis.set('estimated_avg_tps', estimated_avg_tps)
            time.sleep(60)


def cache_stats():
    """Regenerate the stats cache every 5 seconds; intended to run in its own thread."""
    while True:
        generate_stats(regen=True)
        time.sleep(5)


redis_moderation = redis_redis.Redis()


def start_moderation_workers(num_workers):
    """Spawn `num_workers` daemon threads consuming the moderation queue."""
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker, daemon=True)
        t.start()


def moderation_worker():
    """Blocking consumer loop: pop a (msg, tag) task, run the moderation check,
    and push (tag, categories) onto the results queue.

    Malformed tasks are logged and skipped so one bad message cannot kill the worker.
    """
    while True:
        result = redis_moderation.blpop('queue:msgs_to_check')
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except Exception:
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            print(result)
            traceback.print_exc()


def add_moderation_task(msg, tag):
    """Enqueue `msg` for moderation; `tag` is stringified to match get_results()."""
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))


def get_results(tag, num_tasks):
    """Block until `num_tasks` moderation results for `tag` arrive and return the
    union of their flagged categories as a list."""
    tag = str(tag)  # Required for comparison with Redis results.
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        result = redis_moderation.blpop('queue:flagged_categories')
        result_tag, categories = json.loads(result[1])
        # NOTE(review): results for other tags are consumed and dropped here, so
        # concurrent callers with different tags can starve each other — confirm
        # whether only one caller ever reads this queue at a time.
        if result_tag == tag:
            if categories:
                flagged_categories.update(categories)
            num_results += 1
    return list(flagged_categories)