try to prevent "### XXX" responses on openai

parent 11e84db59c
commit bbdb9c9d55
@@ -38,7 +38,7 @@ def openai_chat_completions():
     request_valid, invalid_response = handler.validate_request()
     if not request_valid:
         # TODO: simulate OAI here
-        raise Exception
+        raise Exception('TODO: simulate OAI here')
     else:
        handler.prompt = handler.transform_messages_to_prompt()
        msg_to_backend = {
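The invalid-request branch still just raises; a minimal sketch of what "simulate OAI here" could eventually return instead, assuming a Flask handler (the helper name, message, and jsonify usage are assumptions, not the project's code; only the error envelope mirrors OpenAI's invalid_request_error shape):

from flask import jsonify

def simulate_oai_invalid_request(message='Invalid request'):
    # Hypothetical helper: return an OpenAI-style error body so clients
    # speaking the OpenAI API can parse the rejection.
    return jsonify({
        'error': {
            'message': message,
            'type': 'invalid_request_error',
            'param': None,
            'code': None,
        }
    }), 400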
@@ -19,6 +19,9 @@ from llm_server.routes.request_handler import RequestHandler
 
 tokenizer = tiktoken.get_encoding("cl100k_base")
 
+ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
+ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.
+
 
 class OpenAIRequestHandler(RequestHandler):
     def __init__(self, *args, **kwargs):
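A quick sketch of what the two new patterns do to a model reply; the sample text is made up, but the regexes are the ones added above:

import re

ANTI_RESPONSE_RE = re.compile(r'^### (.*?)(?:\:)?\s')  # Match a "### XXX" line.
ANTI_CONTINUATION_RE = re.compile(r'(.*?### .*?(?:\:)?(.|\n)*)')  # Match everything after a "### XXX" line.

response = '### Response: Sure, here you go.\nSome answer text.\n### Instruction: now do something else\nmore text'

# Strip a leading "### XXX" header from the reply.
response = re.sub(ANTI_RESPONSE_RE, '', response)
# Drop everything from the first remaining "### XXX" line onward.
response = re.sub(ANTI_CONTINUATION_RE, '', response)

print(repr(response))  # 'Sure, here you go.\nSome answer text.\n'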
@@ -145,9 +148,11 @@ def build_openai_response(prompt, response, model=None):
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))
 
     # Make sure the bot doesn't put any other instructions in its response
-    y = response.split('\n### ')
-    if len(x) > 1:
-        response = re.sub(r'\n$', '', y[0].strip(' '))
+    # y = response.split('\n### ')
+    # if len(y) > 1:
+    # response = re.sub(r'\n$', '', y[0].strip(' '))
+    response = re.sub(ANTI_RESPONSE_RE, '', response)
+    response = re.sub(ANTI_CONTINUATION_RE, '', response)
 
     # TODO: async/await
     prompt_tokens = llm_server.llm.get_token_count(prompt)
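For context on the token counts used just after the cleanup: a rough sketch of how the OpenAI-style usage block could be assembled with the cl100k_base tokenizer loaded above. llm_server.llm.get_token_count() is the project's own helper, so treating it as a plain tiktoken count (and the sample strings) is an assumption:

import tiktoken

tokenizer = tiktoken.get_encoding("cl100k_base")

def get_token_count(text: str) -> int:
    # Assumed stand-in for llm_server.llm.get_token_count().
    return len(tokenizer.encode(text))

prompt = 'Write a haiku about servers.'   # sample prompt after transform_messages_to_prompt()
response = 'Fans hum in the dark.'        # sample response after the regex cleanup
usage = {
    'prompt_tokens': get_token_count(prompt),
    'completion_tokens': get_token_count(response),
    'total_tokens': get_token_count(prompt) + get_token_count(response),
}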
@@ -21,6 +21,7 @@ from llm_server.routes.server_error import handle_server_error
 # TODO: option to trim context in openai mode so that we silently fit the model's context
 # TODO: validate system tokens before excluding them
 # TODO: make sure prompts are logged even when the user cancels generation
+# TODO: add some sort of loadbalancer to send requests to a group of backends
 
 # TODO: make sure log_prompt() is used everywhere, including errors and invalid requests
 # TODO: unify logging thread in a function and use async/await instead