2023-08-30 18:53:26 -06:00
from flask import jsonify
2023-09-12 01:04:11 -06:00
from . . llm_backend import LLMBackend
2023-09-20 20:30:31 -06:00
from . . . database . database import log_prompt
2023-08-30 18:53:26 -06:00
from . . . helpers import safe_list_get
from . . . routes . cache import redis
from . . . routes . helpers . client import format_sillytavern_err
from . . . routes . helpers . http import validate_json
2023-09-12 16:40:09 -06:00
class OobaboogaBackend(LLMBackend):
    """Backend adapter for an oobabooga inference server.

    NOTE(review): ``handle_response`` currently raises ``NotImplementedError``
    unconditionally because ``default_params`` has not been filled in yet.
    The response-handling logic after that raise is unreachable today and is
    kept (and kept correct) for when the stub is removed.
    """

    # Default generation parameters for this backend. Intentionally empty for
    # now -- see the NotImplementedError at the top of handle_response.
    default_params = {}

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        """Convert a raw backend response into a Flask JSON response and log it.

        Args:
            success: Whether the call to the backend reported success.
            request: The incoming request; only its ``url`` is read (for logging).
                Presumably a Flask request -- confirm against the caller.
            response: The backend's HTTP response object, or a falsy value if
                there is none.
            error_msg: Error text produced while contacting the backend, if any.
            client_ip, token, prompt, parameters, headers: Request metadata that
                is passed straight through to ``log_prompt``.
            elapsed_time: Generation time; logged only for non-error results.

        Returns:
            A ``(jsonify(...), status)`` tuple. On success this is the backend's
            JSON body with status 200. On any failure the status is 400 and the
            body carries ``'code': 500``, a ``'msg'`` and a SillyTavern-formatted
            error in ``results[0]['text']``.

        Raises:
            NotImplementedError: Always, until ``default_params`` is implemented.
        """
        raise NotImplementedError('need to implement default_params')

        # ---- Unreachable until the stub above is removed. ----
        backend_err = False
        response_valid_json, response_json_body = validate_json(response)

        if response:
            try:
                # Be extra careful when getting attributes from the response object
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # ===============================================

        # We encountered an error
        if not success or not response or error_msg:
            # Normalize to a single trailing period; fall back to a generic
            # message when nothing meaningful is left after trimming.
            error_msg = (error_msg or '').strip().strip('.')
            error_msg = f'{error_msg}.' if error_msg else 'Unknown error.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': error_msg,
                'results': [{'text': backend_response}]
            }), 400

        # ===============================================

        if response_valid_json:
            # Treat a missing or null 'results' the same as an empty list.
            results = response_json_body.get('results') or []
            backend_response = safe_list_get(results, 0, {}).get('text')
            if not backend_response:
                # Ooba doesn't return any error messages so we will just tell the client an error occurred
                backend_err = True
                backend_response = format_sillytavern_err(
                    'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                    'error')
                # Write the error back into the body. If there was no first
                # result to overwrite, create one instead of raising
                # IndexError/KeyError.
                if results:
                    results[0]['text'] = backend_response
                else:
                    response_json_body['results'] = [{'text': backend_response}]
            if not backend_err:
                redis.incr('proompts')
            # 'details' may be absent or null; either way the token count is unknown.
            response_tokens = (response_json_body.get('details') or {}).get('generated_tokens')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, request.url, response_tokens=response_tokens, is_error=backend_err)
            return jsonify(response_json_body), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            # Use the defensively obtained status code, consistent with the error path above.
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response_status_code, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 400

    def validate_params(self, params_dict: dict):
        """Validate request parameters for this backend.

        Returns:
            ``(True, None)`` unconditionally: no validation is required here.
        """
        # No validation required
        return True, None

    def get_parameters(self, parameters):
        """Strip the prompt from the request parameters.

        The dict is modified in place (callers may rely on that) and is also
        returned. A missing ``'prompt'`` key is tolerated rather than raising
        ``KeyError``.
        """
        parameters.pop('prompt', None)
        return parameters