import json
import re
import time
import traceback
from typing import Tuple
from uuid import uuid4

import flask
from flask import Response, jsonify, make_response

from llm_server import opts
from llm_server.cluster.backend import get_model_choices
from llm_server.custom_redis import redis
from llm_server.database.database import is_api_key_moderated
from llm_server.database.log_to_db import log_to_db
from llm_server.llm import get_token_count
from llm_server.llm.openai.oai_to_vllm import oai_to_vllm, validate_oai, return_invalid_model_err
from llm_server.llm.openai.transform import ANTI_CONTINUATION_RE, ANTI_RESPONSE_RE, generate_oai_string, transform_messages_to_prompt, trim_messages_to_fit
from llm_server.routes.request_handler import RequestHandler
from llm_server.workers.moderator import add_moderation_task, get_results


class OpenAIRequestHandler(RequestHandler):
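    """Request handler for the OpenAI-compatible endpoints: converts OpenAI-style chat
    requests into the backend's prompt/parameter format and wraps backend output in
    OpenAI-style JSON responses."""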
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.prompt = None

    def handle_request(self) -> Tuple[flask.Response, int]:
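        """Build the prompt from the OpenAI messages, convert the request for the backend,
        optionally run content moderation, then proxy the request and wrap the backend's
        output in an OpenAI-style response."""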
        assert not self.used

        if self.offline:
            msg = return_invalid_model_err(self.selected_model)
            print(msg)
            return self.handle_error(msg)

        if opts.openai_silent_trim:
            oai_messages = trim_messages_to_fit(self.request.json['messages'], self.cluster_backend_info['model_config']['max_position_embeddings'], self.backend_url)
        else:
            oai_messages = self.request.json['messages']

        self.prompt = transform_messages_to_prompt(oai_messages)
        self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        if not self.prompt:
            # TODO: format this as an openai error message
            return Response('Invalid prompt'), 400

        # TODO: support Ooba backend
        self.parameters = oai_to_vllm(self.parameters, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

        invalid_oai_err_msg = validate_oai(self.request_json_body)
        if invalid_oai_err_msg:
            return invalid_oai_err_msg
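
        # When moderation is enabled and this API key is subject to it, scan the last
        # opts.openai_moderation_scan_last_n messages with the moderation worker and,
        # if anything is flagged, append a system message telling the model to comply.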
        if opts.openai_moderation_enabled and opts.openai_api_key and is_api_key_moderated(self.token):
            try:
                # Scan the most recent messages first.
                msg_l = self.request.json['messages'].copy()
                msg_l.reverse()
                tag = uuid4()
                num_to_check = min(len(msg_l), opts.openai_moderation_scan_last_n)
                for i in range(num_to_check):
                    add_moderation_task(msg_l[i]['content'], tag)
                flagged_categories = get_results(tag, num_to_check)
                if len(flagged_categories):
                    mod_msg = f"The user's message does not comply with {opts.openai_org_name} policies. Offending categories: {json.dumps(flagged_categories)}. You are instructed to creatively adhere to these policies."
                    self.request.json['messages'].insert(len(self.request.json['messages']), {'role': 'system', 'content': mod_msg})
                    self.prompt = transform_messages_to_prompt(self.request.json['messages'])
            except Exception as e:
                print('OpenAI moderation endpoint failed:', f'{e.__class__.__name__}: {e}')
                traceback.print_exc()

        llm_request = {**self.parameters, 'prompt': self.prompt}
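        # generate_response() returns two tuples: the first begins with a success flag
        # (its remaining values are unused here), the second is (backend response, status code).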
        (success, _, _, _), (backend_response, backend_response_status_code) = self.generate_response(llm_request)
        model = self.request_json_body.get('model')

        if success:
            return self.build_openai_response(self.prompt, backend_response.json['results'][0]['text'], model=model), backend_response_status_code
        else:
            return backend_response, backend_response_status_code

    def handle_ratelimited(self, do_log: bool = True):
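        """Return an OpenAI-style 429 response with x-ratelimit-* headers based on the
        default model's estimated wait time, optionally logging the rejected request."""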
        model_choices, default_model = get_model_choices()
        default_model_info = model_choices[default_model]
        w = int(default_model_info['estimated_wait']) if default_model_info['estimated_wait'] > 0 else 2

        response = jsonify({
            "error": {
                "message": "Rate limit reached on tokens per min. Limit: 10000 / min. Please try again in 6s. Contact us through our help center at help.openai.com if you continue to have issues.",
                "type": "rate_limit_exceeded",
                "param": None,
                "code": None
            }
        })
        response.headers['x-ratelimit-limit-requests'] = '2'
        response.headers['x-ratelimit-remaining-requests'] = '0'
        response.headers['x-ratelimit-reset-requests'] = f"{w}s"
        if do_log:
            log_to_db(self.client_ip, self.token, self.request_json_body.get('prompt', ''), response.data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
        return response, 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
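        """Log the error server-side and return a generic OpenAI-style invalid_request_error with status 400."""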
        print(error_msg)
        return jsonify({
            "error": {
                "message": "Invalid request, check your parameters and try again.",
                "type": "invalid_request_error",
                "param": None,
                "code": None
            }
        }), 400

    def build_openai_response(self, prompt, response, model=None):
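        """Strip the context from the prompt, clean the backend's text, and wrap it in an
        OpenAI chat.completion JSON body with token usage counts."""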
        # Separate the user's prompt from the context
        x = prompt.split('### USER:')
        if len(x) > 1:
            prompt = re.sub(r'\n$', '', x[-1].strip(' '))

        # Make sure the bot doesn't put any other instructions in its response
        response = re.sub(ANTI_RESPONSE_RE, '', response)
        response = re.sub(ANTI_CONTINUATION_RE, '', response)

        prompt_tokens = get_token_count(prompt, self.backend_url)
        response_tokens = get_token_count(response, self.backend_url)
        running_model = redis.get('running_model', 'ERROR', dtype=str)

        response = make_response(jsonify({
            "id": f"chatcmpl-{generate_oai_string(30)}",
            "object": "chat.completion",
            "created": int(time.time()),
            "model": running_model if opts.openai_expose_our_model else model,
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": response,
                },
                "logprobs": None,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": prompt_tokens,
                "completion_tokens": response_tokens,
                "total_tokens": prompt_tokens + response_tokens
            }
        }), 200)
        return response

    def validate_request(self, prompt: str = None, do_log: bool = False) -> Tuple[bool, Tuple[Response | None, int]]:
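        """Check the sampling parameters and OpenAI-specific fields, then defer to the
        superclass for the remaining validation."""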
        self.parameters, parameters_invalid_msg = self.get_parameters()
        if not self.parameters:
            print('OAI BACKEND VALIDATION ERROR:', parameters_invalid_msg)
            return False, (Response('Invalid request, check your parameters and try again.'), 400)

        invalid_oai_err_msg = validate_oai(self.parameters)
        if invalid_oai_err_msg:
            return False, invalid_oai_err_msg

        # self.request_json_body = oai_to_vllm(self.request_json_body, stop_hashes=('instruct' not in self.request_json_body['model'].lower()), mode=self.cluster_backend_info['mode'])

        # If the parameters were invalid, let the superclass deal with it.
        return super().validate_request(prompt, do_log)