62 lines
1.9 KiB
Python
62 lines
1.9 KiB
Python
import json
|
|
import threading
|
|
import time
|
|
import traceback
|
|
|
|
import redis as redis_redis
|
|
|
|
from llm_server.config.global_config import GlobalConfig
|
|
from llm_server.llm.openai.moderation import check_moderation_endpoint
|
|
from llm_server.logging import create_logger
|
|
|
|
# Module-wide Redis client shared by the enqueue side, the result collector and the worker
# threads. The arguments below are redis-py's own defaults, written out for visibility.
redis_moderation = redis_redis.Redis(host='localhost', port=6379, db=0)
|
|
|
|
|
|
def start_moderation_workers(num_workers):
    """
    Start ``num_workers`` background threads, each running ``moderation_worker``.

    The threads are daemon threads, so they do not keep the process alive at shutdown.

    :param num_workers: Number of worker threads to start; zero or a negative value starts none.
    """
    for index in range(num_workers):
        # Named threads make the workers identifiable in thread dumps and log records.
        worker = threading.Thread(
            target=moderation_worker,
            name=f'moderation-worker-{index}',
            daemon=True,
        )
        worker.start()
|
|
|
|
|
|
# TODO: don't use UUID tags to identify items; use a native Redis mechanism instead.
|
|
|
|
def get_results(tag, num_tasks):
    """
    Collect the flagged moderation categories for the tasks submitted under ``tag``.

    Pops worker results off the shared ``queue:flagged_categories`` list until ``num_tasks``
    results carrying ``tag`` have been seen, or until ``openai_moderation_timeout`` seconds
    (read once from ``GlobalConfig``) have elapsed.

    :param tag: Identifier the tasks were enqueued with (e.g. a UUID4); compared as ``str(tag)``.
    :param num_tasks: Number of results to wait for.
    :return: List of the distinct flagged categories (unordered). On timeout, whatever was
             collected so far is returned.
    """
    import math

    tag = str(tag)  # Cast a UUID4 to a string.
    timeout = GlobalConfig.get().openai_moderation_timeout
    deadline = time.time() + timeout
    flagged_categories = set()
    num_results = 0

    while num_results < num_tasks:
        remaining = deadline - time.time()
        if remaining <= 0:
            create_logger('moderator').warning('Timed out waiting for result from moderator')
            break

        # Only wait for the time left. BLPOP treats 0 as "block forever" and Redis < 6.0
        # accepts only whole seconds, so round up and never pass less than 1.
        result = redis_moderation.blpop(['queue:flagged_categories'], timeout=max(1, math.ceil(remaining)))
        if result is None:
            create_logger('moderator').warning('Timed out waiting for result from moderator')
            break

        result_tag, categories = json.loads(result[1])
        # NOTE(review): results for a *different* tag are popped and dropped here, so
        # concurrent get_results() calls can consume each other's results. Fixing this needs
        # per-tag result keys (writer and reader changed together).
        if result_tag == tag:
            if categories:
                flagged_categories.update(categories)
            num_results += 1

    return list(flagged_categories)
|
|
|
|
|
|
def moderation_worker():
    """
    Loop forever, moderating messages taken from ``queue:msgs_to_check``.

    Each entry is a JSON ``[msg, tag]`` pair, as written by ``add_moderation_task``. The message
    is passed to ``check_moderation_endpoint``. The returned categories are pushed onto
    ``queue:flagged_categories`` as a JSON ``[tag, categories]`` pair for ``get_results``.

    Errors are logged and the loop keeps going. A single bad item, or a failing Redis call,
    therefore does not terminate the worker thread.
    """
    logger = create_logger('moderator')
    while True:
        try:
            result = redis_moderation.blpop(['queue:msgs_to_check'])
        except Exception:
            # E.g. Redis unreachable: log, pause briefly so we don't spin, then retry.
            logger.error(traceback.format_exc())
            time.sleep(1)
            continue
        if result is None:
            continue

        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except Exception:
            # Catch Exception, not everything, so SystemExit/KeyboardInterrupt still propagate.
            logger.error(traceback.format_exc())
|
|
|
|
|
|
def add_moderation_task(msg, tag):
    """
    Enqueue ``msg`` for a moderation worker.

    The entry pushed onto ``queue:msgs_to_check`` is the JSON array ``[msg, str(tag)]``, so
    ``msg`` must be JSON-serialisable. ``tag`` is stringified so results can later be
    matched by ``get_results``.
    """
    payload = json.dumps([msg, str(tag)])
    redis_moderation.rpush('queue:msgs_to_check', payload)
|