import json
import traceback
from typing import Tuple
from uuid import uuid4

import flask
from flask import jsonify

from llm_server import opts
from llm_server.database.database import is_api_key_moderated
from llm_server.llm.openai.transform import build_openai_response, transform_messages_to_prompt, trim_prompt_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results


class OpenAIRequestHandler(RequestHandler):
    """Adapt OpenAI-style chat-completion requests to the backend LLM.

    Flattens the incoming ``messages`` array into a single prompt string,
    optionally runs a moderation pass over the most recent messages, and
    wraps the backend result in an OpenAI-compatible response.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Prompt string built from the OpenAI `messages` list; populated in handle_request().
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
        """Process one chat-completion request; return (response, status_code)."""
        # A handler instance is single-use; reuse is a programming error.
        assert not self.used

        if opts.openai_silent_trim:
            # Silently drop the oldest messages so the prompt fits the context window.
            oai_messages = trim_prompt_to_fit(self.request.json['messages'], opts.context_size)
        else:
            oai_messages = self.request.json['messages']

        self.prompt = transform_messages_to_prompt(oai_messages)

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        if opts.openai_api_key and is_api_key_moderated(self.token):
            # Best-effort moderation: any failure is logged and the request proceeds.
            try:
                # Walk the conversation newest-first and scan the last N messages.
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()
                tag = uuid4()
                num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
                for i in range(num_to_check):
                    add_moderation_task(msg_l[i]['content'], tag)
                flagged_categories = get_results(tag, num_to_check)
                if flagged_categories:
                    # Append a system message steering the model back toward policy.
                    mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                    self.request.json['messages'].append({'role': 'system', 'content': mod_msg})
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
                print('OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                print(traceback.format_exc())

        # Reconstruct the request JSON with the validated parameters and prompt.
        self.parameters['stop'].extend(['\n### INSTRUCTION', '\n### USER', '\n### ASSISTANT', '\n### RESPONSE'])
        if opts.openai_force_no_hashes:
            # NOTE(review): confirm whether this stop token should carry a trailing space.
            self.parameters['stop'].append('###')

        # vllm rejects top_p == 0; nudge it to the smallest accepted value.
        if opts.mode == 'vllm' and self.request_json_body.get('top_p') == 0:
            self.request_json_body['top_p'] = 0.01

        llm_request = {**self.parameters, 'prompt': self.prompt}
        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        model = self.request_json_body.get('model')

        if success:
            return build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
        else:
            return backend_response, backend_response_status_code

    def handle_ratelimited(self, do_log: bool = True):
        """Return a plain-text 429 when the client exceeds its request limit."""
        # TODO: return a simulated OpenAI error message
        # Ratelimited: you are only allowed to have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.
        return 'Ratelimited', 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
        """Return a generic OpenAI-shaped 400 error (error_msg/error_type currently unused)."""
        # TODO: return a simulated OpenAI error message incorporating error_msg and error_type
        return jsonify({
            "error": {
                "message": "Invalid request, check your parameters and try again.",
                "type": "invalid_request_error",
                "param": None,
                "code": None
            }
        }), 400