import json
import math
import threading
import time
import traceback

import redis as redis_redis

from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint

redis_moderation = redis_redis.Redis()

# List that add_moderation_task() pushes (msg, tag) pairs onto and the
# moderation workers pop from.
_TASK_QUEUE = 'queue:msgs_to_check'

# Prefix of the per-tag result lists. Each caller of get_results() waits only
# on its own list, so concurrent requests can no longer pop (and discard)
# each other's results, as they could with a single shared result list.
_RESULT_QUEUE_PREFIX = 'queue:flagged_categories'

# Result lists expire after this many seconds so results that arrive after
# their waiter gave up do not accumulate in redis forever.
_RESULT_TTL_SECONDS = 300


def _result_key(tag):
    """Return the redis list key holding moderation results for ``tag``."""
    # The f-string also casts a UUID tag to str, matching add_moderation_task().
    return f'{_RESULT_QUEUE_PREFIX}:{tag}'


def start_moderation_workers(num_workers):
    """Start ``num_workers`` daemon threads, each running moderation_worker()."""
    started = 0
    for _ in range(num_workers):
        threading.Thread(target=moderation_worker, daemon=True).start()
        started += 1
    print(f'Started {started} moderation workers.')


def get_results(tag, num_tasks):
    """Collect moderation results for ``tag`` and return the flagged categories.

    Waits for up to ``num_tasks`` results pushed by moderation_worker() for this
    tag, bounded by a single overall deadline of
    ``opts.openai_moderation_timeout`` seconds. If the deadline passes, the
    categories gathered so far are returned.

    :param tag: identifier passed to add_moderation_task() (e.g. a UUID4).
    :param num_tasks: number of tasks queued under ``tag``.
    :return: list of unique flagged category names (order unspecified).
    """
    key = _result_key(tag)
    flagged_categories = set()
    num_results = 0
    deadline = time.time() + opts.openai_moderation_timeout
    try:
        while num_results < num_tasks:
            remaining = deadline - time.time()
            if remaining <= 0:
                print('----> Timed out waiting for result from moderator.')
                break
            # Block only for the time left. redis-py treats timeout=0 as "block
            # forever", so never pass less than 1; rounding up to an int also
            # satisfies redis servers that reject fractional timeouts.
            result = redis_moderation.blpop([key], timeout=max(1, math.ceil(remaining)))
            if result is None:
                print('----> Timed out waiting for result from moderator.')
                break
            _, categories = json.loads(result[1])
            flagged_categories.update(categories or ())
            num_results += 1
    finally:
        # Drop any leftover results; late arrivals expire via the worker's TTL.
        redis_moderation.delete(key)
    return list(flagged_categories)


def moderation_worker():
    """Loop forever: pop a task, moderate it, push the result to its tag's list.

    Errors are printed and the loop continues, so a malformed task or a
    transient redis/endpoint failure does not permanently kill the thread.
    """
    while True:
        try:
            item = redis_moderation.blpop([_TASK_QUEUE])
        except Exception:
            traceback.print_exc()
            time.sleep(1)  # Back off so a redis outage doesn't busy-spin.
            continue

        try:
            msg, tag = json.loads(item[1])
        except Exception:
            traceback.print_exc()
            continue  # Malformed task: there is no tag to report back to.

        try:
            _, categories = check_moderation_endpoint(msg)
        except Exception:
            traceback.print_exc()
            # Report "nothing flagged" so the waiter isn't stuck until its
            # timeout. This is the same fail-open outcome a missing result had
            # before, just without the stall.
            categories = []

        try:
            key = _result_key(tag)
            with redis_moderation.pipeline() as pipe:
                pipe.rpush(key, json.dumps((tag, categories)))
                pipe.expire(key, _RESULT_TTL_SECONDS)
                pipe.execute()
        except Exception:
            traceback.print_exc()


def add_moderation_task(msg, tag):
    """Queue ``msg`` for moderation; results are reported under ``tag``."""
    redis_moderation.rpush(_TASK_QUEUE, json.dumps((msg, str(tag))))