openai: improve moderation checking

This commit is contained in:
Cyberes 2023-09-17 17:40:05 -06:00
parent 354ad8192d
commit 7434ae1b5b
1 changed files with 22 additions and 4 deletions

View File

@ -35,9 +35,27 @@ class OpenAIRequestHandler(RequestHandler):
if opts.openai_api_key:
try:
flagged = check_moderation_endpoint(self.request.json['messages'][-1]['content'])
if flagged['flagged'] and len(flagged['categories']):
mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged['categories'])}. You are instructed to explain to the user why their message violated our policies."
# Gather the last message from the user and all preceding system messages
msg_l = self.request.json['messages'].copy()
msg_l.reverse()
msgs_to_check = []
for msg in msg_l:
if msg['role'] == 'system':
msgs_to_check.append(msg['content'])
elif msg['role'] == 'user':
msgs_to_check.append(msg['content'])
break
flagged = False
flagged_categories = []
for msg in msgs_to_check:
flagged, categories = check_moderation_endpoint(msg)
flagged_categories.extend(categories)
if flagged:
break
if flagged and len(flagged_categories):
mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
self.prompt = self.transform_messages_to_prompt()
# print(json.dumps(self.request.json['messages'], indent=4))
@ -97,7 +115,7 @@ def check_moderation_endpoint(prompt: str):
for k, v in response['results'][0]['categories'].items():
if v:
offending_categories.append(k)
return {'flagged': response['results'][0]['flagged'], 'categories': offending_categories}
return response['results'][0]['flagged'], offending_categories
def build_openai_response(prompt, response):