openai: improve moderation checking
This commit is contained in:
parent
354ad8192d
commit
7434ae1b5b
|
@ -35,9 +35,27 @@ class OpenAIRequestHandler(RequestHandler):
|
||||||
|
|
||||||
if opts.openai_api_key:
|
if opts.openai_api_key:
|
||||||
try:
|
try:
|
||||||
flagged = check_moderation_endpoint(self.request.json['messages'][-1]['content'])
|
# Gather the last message from the user and all preceding system messages
|
||||||
if flagged['flagged'] and len(flagged['categories']):
|
msg_l = self.request.json['messages'].copy()
|
||||||
mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged['categories'])}. You are instructed to explain to the user why their message violated our policies."
|
msg_l.reverse()
|
||||||
|
msgs_to_check = []
|
||||||
|
for msg in msg_l:
|
||||||
|
if msg['role'] == 'system':
|
||||||
|
msgs_to_check.append(msg['content'])
|
||||||
|
elif msg['role'] == 'user':
|
||||||
|
msgs_to_check.append(msg['content'])
|
||||||
|
break
|
||||||
|
|
||||||
|
flagged = False
|
||||||
|
flagged_categories = []
|
||||||
|
for msg in msgs_to_check:
|
||||||
|
flagged, categories = check_moderation_endpoint(msg)
|
||||||
|
flagged_categories.extend(categories)
|
||||||
|
if flagged:
|
||||||
|
break
|
||||||
|
|
||||||
|
if flagged and len(flagged_categories):
|
||||||
|
mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
|
||||||
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
|
self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
|
||||||
self.prompt = self.transform_messages_to_prompt()
|
self.prompt = self.transform_messages_to_prompt()
|
||||||
# print(json.dumps(self.request.json['messages'], indent=4))
|
# print(json.dumps(self.request.json['messages'], indent=4))
|
||||||
|
@ -97,7 +115,7 @@ def check_moderation_endpoint(prompt: str):
|
||||||
for k, v in response['results'][0]['categories'].items():
|
for k, v in response['results'][0]['categories'].items():
|
||||||
if v:
|
if v:
|
||||||
offending_categories.append(k)
|
offending_categories.append(k)
|
||||||
return {'flagged': response['results'][0]['flagged'], 'categories': offending_categories}
|
return response['results'][0]['flagged'], offending_categories
|
||||||
|
|
||||||
|
|
||||||
def build_openai_response(prompt, response):
|
def build_openai_response(prompt, response):
|
||||||
|
|
Loading…
Reference in New Issue