# local-llm-server/llm_server/workers/moderator.py
import json
import threading
import time
import traceback
import redis as redis_redis
from llm_server import opts
from llm_server.llm.openai.moderation import check_moderation_endpoint
from llm_server.logging import create_logger
# Shared Redis connection for the moderation queues (default host/port/db).
redis_moderation = redis_redis.Redis()
def start_moderation_workers(num_workers):
    """Spawn ``num_workers`` daemon threads running ``moderation_worker``.

    Threads are daemonized so they never block interpreter shutdown.

    :param num_workers: number of worker threads to start.
    """
    # The original kept an `i` counter that was incremented but never read;
    # it has been removed.
    for _ in range(num_workers):
        t = threading.Thread(target=moderation_worker, daemon=True)
        t.start()
# TODO: don't use UUID tags to identify items. Use native redis
def get_results(tag, num_tasks):
    """Collect the flagged moderation categories for ``tag`` from Redis.

    Blocks popping 'queue:flagged_categories' until ``num_tasks`` results
    matching ``tag`` have been seen, or ``opts.openai_moderation_timeout``
    seconds elapse (both per-pop and overall), whichever comes first.

    :param tag: UUID (or string) identifying this batch of moderation tasks.
    :param num_tasks: number of results expected for this tag.
    :return: list of unique flagged category names (possibly empty).
    """
    # Bug fix: `logger` was an undefined name here — the timeout warning
    # below raised NameError instead of logging.
    logger = create_logger('moderator')
    tag = str(tag)  # Cast a UUID4 to a string.
    flagged_categories = set()
    num_results = 0
    start_time = time.time()
    while num_results < num_tasks:
        result = redis_moderation.blpop(['queue:flagged_categories'], timeout=opts.openai_moderation_timeout)
        if result is None:
            break  # Timeout occurred, break the loop.
        result_tag, categories = json.loads(result[1])
        if result_tag == tag:
            if categories:
                flagged_categories.update(categories)
            num_results += 1
        # NOTE(review): a result whose tag does not match is consumed and
        # dropped here, so concurrent callers can lose each other's results —
        # see the TODO above about moving off UUID tags.
        if time.time() - start_time > opts.openai_moderation_timeout:
            logger.warning('Timed out waiting for result from moderator')
            break
    return list(flagged_categories)
def moderation_worker():
    """Worker loop: moderate queued messages and publish the results.

    Pops ``(msg, tag)`` pairs from 'queue:msgs_to_check', runs ``msg``
    through ``check_moderation_endpoint``, and pushes ``(tag, categories)``
    onto 'queue:flagged_categories'. Runs forever; a failure on one message
    is logged and the loop moves on to the next.
    """
    logger = create_logger('moderator')
    while True:
        result = redis_moderation.blpop(['queue:msgs_to_check'])
        try:
            msg, tag = json.loads(result[1])
            _, categories = check_moderation_endpoint(msg)
            redis_moderation.rpush('queue:flagged_categories', json.dumps((tag, categories)))
        except Exception:
            # Never let one bad message kill the worker. The original bare
            # `except:` also swallowed KeyboardInterrupt/SystemExit, making
            # the thread impossible to interrupt cleanly; the trailing
            # `continue` was redundant and has been dropped.
            logger.error(traceback.format_exc())
def add_moderation_task(msg, tag):
    """Enqueue ``msg`` for moderation, tagged so ``get_results`` can match it.

    :param msg: text to be checked by the moderation worker.
    :param tag: UUID (or string) identifying the batch this task belongs to.
    """
    payload = json.dumps((msg, str(tag)))
    redis_moderation.rpush('queue:msgs_to_check', payload)