From 7434ae1b5b56f25a519dee6fb1d90f3f811fd45b Mon Sep 17 00:00:00 2001
From: Cyberes
Date: Sun, 17 Sep 2023 17:40:05 -0600
Subject: [PATCH] openai: improve moderation checking

---
 llm_server/routes/openai_request_handler.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/llm_server/routes/openai_request_handler.py b/llm_server/routes/openai_request_handler.py
index d3292ed..c2d5401 100644
--- a/llm_server/routes/openai_request_handler.py
+++ b/llm_server/routes/openai_request_handler.py
@@ -35,9 +35,27 @@ class OpenAIRequestHandler(RequestHandler):
 
         if opts.openai_api_key:
             try:
-                flagged = check_moderation_endpoint(self.request.json['messages'][-1]['content'])
-                if flagged['flagged'] and len(flagged['categories']):
-                    mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged['categories'])}. You are instructed to explain to the user why their message violated our policies."
+                # Gather the last message from the user and all preceding system messages
+                msg_l = self.request.json['messages'].copy()
+                msg_l.reverse()
+                msgs_to_check = []
+                for msg in msg_l:
+                    if msg['role'] == 'system':
+                        msgs_to_check.append(msg['content'])
+                    elif msg['role'] == 'user':
+                        msgs_to_check.append(msg['content'])
+                        break
+
+                flagged = False
+                flagged_categories = []
+                for msg in msgs_to_check:
+                    flagged, categories = check_moderation_endpoint(msg)
+                    flagged_categories.extend(categories)
+                    if flagged:
+                        break
+
+                if flagged and len(flagged_categories):
+                    mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to explain to the user why their message violated our policies."
                     self.request.json['messages'].insert((len(self.request.json['messages'])), {'role': 'system', 'content': mod_msg})
                     self.prompt = self.transform_messages_to_prompt()
                     # print(json.dumps(self.request.json['messages'], indent=4))
@@ -97,7 +115,7 @@ def check_moderation_endpoint(prompt: str):
     for k, v in response['results'][0]['categories'].items():
         if v:
             offending_categories.append(k)
-    return {'flagged': response['results'][0]['flagged'], 'categories': offending_categories}
+    return response['results'][0]['flagged'], offending_categories
 
 
 def build_openai_response(prompt, response):