try to prevent "### XXX" responses on openai

This commit is contained in:
Cyberes 2023-09-25 23:14:35 -06:00
parent 11e84db59c
commit bbdb9c9d55
3 changed files with 10 additions and 4 deletions

View File

@ -38,7 +38,7 @@ def openai_chat_completions():
request_valid, invalid_response = handler.validate_request()
if not request_valid:
# TODO: simulate OAI here
raise Exception
raise Exception('TODO: simulate OAI here')
else:
handler.prompt = handler.transform_messages_to_prompt()
msg_to_backend = {

View File

@ -19,6 +19,9 @@ from llm_server.routes.request_handler import RequestHandler
tokenizer = tiktoken.get_encoding("cl100k_base")
ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s') # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)') # Match everything after a "### XXX" line.
class OpenAIRequestHandler(RequestHandler):
def __init__(self, *args, **kwargs):
@ -145,9 +148,11 @@ def build_openai_response(prompt, response, model=None):
prompt = re.sub(r'\n$', '', x[-1].strip(' '))
# Make sure the bot doesn't put any other instructions in its response
y = response.split('\n### ')
if len(x) > 1:
response = re.sub(r'\n$', '', y[0].strip(' '))
# y = response.split('\n### ')
# if len(y) > 1:
# response = re.sub(r'\n$', '', y[0].strip(' '))
response = re.sub(ANTI_RESPONSE_RE, '', response)
response = re.sub(ANTI_CONTINUATION_RE, '', response)
# TODO: async/await
prompt_tokens = llm_server.llm.get_token_count(prompt)

View File

@ -21,6 +21,7 @@ from llm_server.routes.server_error import handle_server_error
# TODO: option to trim context in openai mode so that we silently fit the model's context
# TODO: validate system tokens before excluding them
# TODO: make sure prompts are logged even when the user cancels generation
# TODO: add some sort of loadbalancer to send requests to a group of backends
# TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
# TODO: unify logging thread in a function and use async/await instead