"""Redis-backed fan-out/fan-in for OpenAI moderation checks.

Callers enqueue messages with ``add_moderation_task`` and collect the flagged
categories with ``get_results``. Background threads started by
``start_moderation_workers`` run ``moderation_worker``, which consumes the
message queue and posts results back to Redis.
"""
import functools
import json
import math
import threading
import time
import traceback

import redis as redis_redis

from llm_server.config.global_config import GlobalConfig
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger

redis_moderation = redis_redis.Redis()

# Prefix of the per-tag result lists. Each request gets its own list so that
# concurrent callers of get_results() never pop (and discard) each other's results.
_RESULT_QUEUE_PREFIX = 'queue:flagged_categories:'


@functools.lru_cache(maxsize=1)
def _get_logger():
    """Return the shared 'moderator' logger, created on first use."""
    return create_logger('moderator')


def _result_key(tag):
    """Return the Redis list key that holds moderation results for ``tag``."""
    return f'{_RESULT_QUEUE_PREFIX}{tag}'


def _result_ttl_seconds():
    """Return how long an unread result list may live in Redis.

    Results that arrive after get_results() has given up would otherwise sit
    in Redis forever; the TTL is twice the configured moderation timeout
    (at least 60 s), so a waiting caller always has time to read its results.
    """
    return max(60, math.ceil(GlobalConfig.get().openai_moderation_timeout) * 2)


def start_moderation_workers(num_workers):
    """Start ``num_workers`` daemon threads, each running ``moderation_worker``."""
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker, daemon=True)
        t.start()


# TODO: don't use UUID tags to identify items. Use native redis
def get_results(tag, num_tasks):
    """Wait for up to ``num_tasks`` moderation results for ``tag``.

    :param tag: identifier passed to ``add_moderation_task`` (e.g. a UUID4);
        it is converted to ``str``.
    :param num_tasks: number of results to wait for.
    :return: list of the unique flagged categories collected before all
        results arrived or the configured moderation timeout elapsed
        (order is unspecified).
    """
    tag = str(tag)  # Cast a UUID4 to a string.
    timeout = GlobalConfig.get().openai_moderation_timeout
    result_key = _result_key(tag)
    # The timeout applies to the whole wait, not to each individual BLPOP.
    deadline = time.monotonic() + timeout
    flagged_categories = set()
    num_results = 0
    while num_results < num_tasks:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            _get_logger().warning('Timed out waiting for result from moderator')
            break
        # BLPOP treats timeout=0 as "block forever", so round the remaining
        # time up to whole seconds and never pass anything below 1.
        result = redis_moderation.blpop([result_key], timeout=max(1, math.ceil(remaining)))
        if result is None:
            _get_logger().warning('Timed out waiting for result from moderator')
            break
        result_tag, categories = json.loads(result[1])
        if result_tag != tag:
            # Cannot happen with per-tag keys; kept as a defensive sanity check.
            continue
        if categories:
            flagged_categories.update(categories)
        num_results += 1
    return list(flagged_categories)


def moderation_worker():
    """Consume 'queue:msgs_to_check' forever and post each result back to Redis.

    Each result is pushed as JSON ``(tag, categories)`` onto the tag's own
    result list, which gets a TTL so unread results do not accumulate.
    Errors while handling a message are logged and the worker moves on.
    """
    logger = _get_logger()
    while True:
        result = redis_moderation.blpop(['queue:msgs_to_check'])
        try:
            msg, tag = json.loads(result[1])
            # check_moderation_endpoint returns a 2-tuple; only the second
            # element (the flagged categories) is used here.
            _, categories = check_moderation_endpoint(msg)
            key = _result_key(tag)
            pipe = redis_moderation.pipeline()
            pipe.rpush(key, json.dumps((tag, categories)))
            pipe.expire(key, _result_ttl_seconds())
            pipe.execute()
        except Exception:  # keep the worker thread alive on any per-message failure
            logger.error(traceback.format_exc())


def add_moderation_task(msg, tag):
    """Enqueue ``msg`` for moderation; its results are collected under ``str(tag)``."""
    redis_moderation.rpush('queue:msgs_to_check', json.dumps((msg, str(tag))))