from typing import Tuple

import flask
from flask import jsonify, request

from llm_server import opts
from llm_server.database.database import log_prompt
from llm_server.routes.helpers.client import format_sillytavern_err
from llm_server.routes.request_handler import RequestHandler


class OobaRequestHandler(RequestHandler):
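    """Handles requests on the text-generation-webui (Ooba) compatible API route.

    Validates the client's parameters, forwards the prompt to the LLM backend,
    and formats error messages so SillyTavern (ST) can display them.
    """
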
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def handle_request(self, return_ok: bool = True):
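        """Validate the request, send it to the backend, and return a Flask response.

        When return_ok is True, the body is returned with HTTP 200 regardless of
        the outcome so SillyTavern will display our error messages; the OpenAI
        route passes False because it needs to detect 429 errors itself.
        """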
        assert not self.used

        request_valid, invalid_response = self.validate_request()
        if not request_valid:
            return invalid_response

        # Reconstruct the request JSON with the validated parameters and prompt.
        prompt = self.request_json_body.get('prompt', '')
        llm_request = {**self.parameters, 'prompt': prompt}

        _, backend_response = self.generate_response(llm_request)
        if return_ok:
            # Always return 200 so ST displays our error messages.
            return backend_response[0], 200
        else:
            # The OpenAI route needs to detect 429 errors.
            return backend_response

    def handle_ratelimited(self, do_log: bool = True):
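        """Respond to a client that exceeded opts.simultaneous_requests_per_ip.

        Reuses handle_error() to build the response body, optionally logs the
        rejected prompt, and returns the response with HTTP 429.
        """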
        msg = f'Ratelimited: you may only have {opts.simultaneous_requests_per_ip} simultaneous requests at a time. Please complete your other requests before sending another.'
        backend_response = self.handle_error(msg)
        if do_log:
            log_prompt(self.client_ip, self.token, self.request_json_body.get('prompt', ''), backend_response[0].data.decode('utf-8'), None, self.parameters, dict(self.request.headers), 429, self.request.url, self.backend_url, is_error=True)
        return backend_response[0], 429

    def handle_error(self, error_msg: str, error_type: str = 'error') -> Tuple[flask.Response, int]:
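        """Wrap error_msg in the standard {'results': [{'text': ...}]} response body.

        Clients may send the header 'LLM-ST-Errors: true' to skip the
        SillyTavern-specific error formatting and receive the raw message instead.
        """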
        disable_st_error_formatting = request.headers.get('LLM-ST-Errors', False) == 'true'
        if disable_st_error_formatting:
            # TODO: decide how to format the error when ST formatting is disabled.
            response_msg = error_msg
        else:
            response_msg = format_sillytavern_err(error_msg, error_type=error_type, backend_url=self.backend_url)
        return jsonify({
            'results': [{'text': response_msg}]
        }), 200  # Return 200 so we don't trigger an error message in the client's ST.
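
# Minimal usage sketch, kept as a comment because the real wiring lives
# elsewhere in llm_server.routes. Everything below is an assumption for
# illustration: the blueprint, the route path, and constructing the handler
# from the incoming flask.Request (the base class only exposes *args/**kwargs).
#
#     from flask import Blueprint, request
#
#     bp = Blueprint('ooba', __name__)
#
#     @bp.route('/api/v1/generate', methods=['POST'])
#     def generate():
#         # One handler per request; handle_request() returns (body, status).
#         return OobaRequestHandler(request).handle_request()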