2023-09-14 15:07:17 -06:00
import json
2023-09-12 16:40:09 -06:00
import re
import time
2023-09-14 15:07:17 -06:00
import traceback
2023-09-14 14:05:50 -06:00
from typing import Tuple
2023-09-12 16:40:09 -06:00
from uuid import uuid4
2023-09-14 14:05:50 -06:00
import flask
2023-09-14 15:07:17 -06:00
import requests
2023-09-12 16:40:09 -06:00
import tiktoken
from flask import jsonify
from llm_server import opts
from llm_server . database import log_prompt
from llm_server . routes . helpers . client import format_sillytavern_err
from llm_server . routes . request_handler import RequestHandler
# Module-level tiktoken encoder (created once at import time); build_openai_response
# uses it to count tokens for the "usage" section of the reply.
tokenizer = tiktoken . get_encoding ( " cl100k_base " )
class OpenAIRequestHandler(RequestHandler):
    """Serve OpenAI-style chat completion requests from the LLM backend.

    The ``messages`` list of the incoming request is flattened into a single
    ``### ROLE: content`` prompt, handed to ``generate_response`` and the
    generated text is wrapped into a ``chat.completion`` body by
    ``build_openai_response``.
    """

    # Section header used in the flattened prompt for each accepted chat role.
    _ROLE_HEADERS = {
        'system': '### INSTRUCTION',
        'user': '### USER',
        'assistant': '### ASSISTANT',
    }

    # Stop sequences added to every request so generation halts before the
    # model starts writing another section on its own.
    _EXTRA_STOP_SEQUENCES = ('\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flattened prompt; stays None until a handler method builds it.
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Validate the request, build the prompt, query the backend and wrap the reply.

        Returns:
            ``(response, status_code)``. On validation failure this is whatever
            ``validate_request`` supplied; if no prompt can be built from the
            messages it is a 400 error body.

        Raises:
            Exception: if this handler instance was already used.
        """
        if self.used:
            raise Exception('This request handler has already been used')

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        self.prompt = self.transform_messages_to_prompt()
        if not self.prompt:
            # transform_messages_to_prompt() yields False/'' for malformed
            # messages; fail cleanly instead of sending that to the backend
            # and crashing later on ``prompt.split``.
            return jsonify({
                'error': {
                    'message': 'Unable to build a prompt from the supplied messages.',
                    'type': 'invalid_request_error',
                    'param': 'messages',
                    'code': None,
                }
            }), 400

        # Moderation has to run BEFORE generation, otherwise the injected
        # system message never reaches the model.
        if opts.openai_api_key:
            try:
                flagged = check_moderation_endpoint(self.request.json['messages'][-1]['content'])
                # The helper always returns a (truthy) dict, so the 'flagged'
                # field itself must be tested.
                if flagged['flagged']:
                    mod_msg = f"The user's message does not comply with {opts.llm_middleware_name} policies. Offending categories: {json.dumps(flagged['categories'])}"
                    self.request.json['messages'].append({'role': 'system', 'content': mod_msg})
                    self.prompt = self.transform_messages_to_prompt()
            except Exception as e:
                # Best effort: a moderation outage must not block the request.
                print(f'OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                print(traceback.format_exc())

        # Reconstruct the request JSON with the validated parameters and prompt.
        self.parameters.setdefault('stop', []).extend(self._EXTRA_STOP_SEQUENCES)
        llm_request = {**self.parameters, 'prompt': self.prompt}

        _, (backend_response, backend_response_status_code) = self.generate_response(llm_request)

        return build_openai_response(self.prompt, backend_response.json['results'][0]['text']), backend_response_status_code

    def handle_ratelimited(self):
        """Log the rejected request and return the rate-limit notice as a chat completion.

        The notice is returned with HTTP 200 while the log entry records 429.
        NOTE(review): presumably so chat clients display the text - confirm.
        """
        backend_response = format_sillytavern_err(f'Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.', 'error')
        log_prompt(ip=self.client_ip, token=self.token, prompt=self.request_json_body.get('prompt', ''), response=backend_response, gen_time=None, parameters=self.parameters, headers=dict(self.request.headers), backend_response_code=429, request_url=self.request.url, is_error=True)
        # self.prompt is still None when the request never reached
        # handle_request(); build it on demand and fall back to an empty prompt.
        prompt = self.prompt or self.transform_messages_to_prompt() or ''
        return build_openai_response(prompt, backend_response), 200

    def transform_messages_to_prompt(self):
        """Flatten the request's chat messages into one instruction-style prompt.

        Returns:
            The prompt string ending in ``### RESPONSE: ``; ``False`` when a
            message lacks a role/content or has an unknown role; ``''`` when
            an unexpected error occurs while reading the messages.
        """
        try:
            prompt = f'### INSTRUCTION: {opts.openai_system_prompt}'
            for msg in self.request.json['messages']:
                if not msg.get('content') or not msg.get('role'):
                    return False
                header = self._ROLE_HEADERS.get(msg['role'])
                if header is None:
                    return False
                prompt += f'{header}: {msg["content"]}\n\n'
        except Exception as e:
            # TODO: use logging
            print(f'Failed to transform OpenAI to prompt:', f'{e.__class__.__name__}: {e}')
            print(traceback.format_exc())
            return ''

        # Trim surrounding spaces, then surrounding newlines, before the final header.
        prompt = prompt.strip(' ').strip('\n')
        prompt += '\n\n### RESPONSE: '
        return prompt
2023-09-14 15:07:17 -06:00
def check_moderation_endpoint(prompt: str, timeout: float = 10):
    """Ask the OpenAI moderation endpoint whether ``prompt`` violates its policies.

    Args:
        prompt: Text to classify.
        timeout: Seconds to wait for the HTTP call before giving up, so a
            stalled moderation service cannot block the request indefinitely.

    Returns:
        ``{'flagged': <bool from the API>, 'categories': [names of the
        categories the API marked true]}``.

    Raises:
        requests.RequestException: on network failure, timeout or a non-2xx reply.
    """
    headers = {
        'Content-Type': 'application/json',
        'Authorization': f"Bearer {opts.openai_api_key}",
    }
    http_response = requests.post('https://api.openai.com/v1/moderations', headers=headers, json={"input": prompt}, timeout=timeout)
    # Surface API errors (bad key, rate limit, ...) as an HTTP error instead
    # of an opaque KeyError on the missing 'results' field.
    http_response.raise_for_status()
    result = http_response.json()['results'][0]
    offending_categories = [name for name, hit in result['categories'].items() if hit]
    return {'flagged': result['flagged'], 'categories': offending_categories}
2023-09-12 16:40:09 -06:00
def build_openai_response(prompt, response):
    """Wrap generated text in an OpenAI ``chat.completion`` JSON response.

    Args:
        prompt: Flattened prompt that was sent to the backend. Only the text
            after the last ``### USER:`` header is counted as prompt tokens.
            A falsy value is treated as an empty prompt.
        response: Text produced by the backend. Anything from the first
            ``\\n### `` section onwards is dropped.

    Returns:
        The ``flask.Response`` produced by ``jsonify`` holding the completion
        body, including token usage counted with the module-level tokenizer.
    """
    # Tolerate a prompt that could not be built (None / False).
    prompt = prompt or ''

    # Separate the user's prompt from the context
    x = prompt.split('### USER:')
    if len(x) > 1:
        prompt = re.sub(r'\n$', '', x[-1].strip(' '))

    # Make sure the bot doesn't put any other instructions in its response.
    # This check must look at the response split (y), not the prompt split.
    y = response.split('\n### ')
    if len(y) > 1:
        response = re.sub(r'\n$', '', y[0].strip(' '))

    prompt_tokens = len(tokenizer.encode(prompt))
    response_tokens = len(tokenizer.encode(response))
    return jsonify({
        "id": f"chatcmpl-{uuid4()}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": opts.running_model,
        "choices": [{
            "index": 0,
            "message": {
                "role": "assistant",
                "content": response,
            },
            "finish_reason": "stop"
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": response_tokens,
            "total_tokens": prompt_tokens + response_tokens
        }
    })