# 2023-09-11 20:47:19 -06:00
# 2023-08-30 18:53:26 -06:00
from typing import Tuple

import requests
from flask import jsonify

from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json
from ..llm_backend import LLMBackend
class OobaboogaLLMBackend(LLMBackend):
    """Backend adapter for oobabooga (text-generation-webui).

    Translates raw backend HTTP responses into the JSON shape the client
    (SillyTavern-style) expects, logging every prompt/response pair.
    """

    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        """Build the client-facing JSON reply for a completed backend request.

        :param success: whether the HTTP request to the backend succeeded.
        :param response: the backend response object (may be None/0 on failure).
        :param error_msg: transport-level error description, if any.
        :param client_ip: requesting client's IP, for logging.
        :param token: client auth token, for logging.
        :param prompt: the prompt that was sent to the backend.
        :param elapsed_time: backend round-trip time; logged only on success.
        :param parameters: generation parameters, for logging.
        :param headers: request headers, for logging.
        :return: ``(flask_response, http_status)`` tuple; status is always 200
            so the client renders the embedded error text instead of failing.
        """
        backend_err = False
        response_valid_json, response_json_body = validate_json(response)

        # `response` may be None (or a placeholder like 0) when the request
        # failed, so guard the attribute access instead of assuming a
        # requests.Response object.
        try:
            response_status_code = response.status_code
        except AttributeError:
            response_status_code = 0

        # ===============================================
        # Transport-level failure: never reached the backend at all.
        if not success or not response:
            backend_response = format_sillytavern_err(f'Failed to reach the backend (oobabooga): {error_msg}', 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response if response else 0, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'failed to reach backend',
                'results': [{'text': backend_response}]
            }), 200
        # ===============================================

        if response_valid_json:
            # `results` may be absent or empty, hence the safe getters.
            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
            if not backend_response:
                # Ooba doesn't return any error messages so we will just tell
                # the client an error occurred.
                backend_err = True
                backend_response = format_sillytavern_err(
                    'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                    'error')
                # Ensure results[0] exists before writing the error text into
                # it — the empty-string case can mean `results` was [] or
                # missing entirely.
                if not response_json_body.get('results'):
                    response_json_body['results'] = [{}]
                response_json_body['results'][0]['text'] = backend_response

            if not backend_err:
                # Only count successful generations.
                redis.incr('proompts')

            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code,
                       response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({
                **response_json_body
            }), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            # Use the guarded status code computed above — `response` might not
            # expose `.status_code`.
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response_status_code, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    def validate_params(self, params_dict: dict):
        """Oobabooga accepts arbitrary generation parameters — nothing to check.

        :return: ``(True, None)`` — (is_valid, error_message) pair.
        """
        return True, None
# 2023-09-11 20:47:19 -06:00
# def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
# try:
# backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
# r_json = backend_response.json()
# return r_json['result'], None
# except Exception as e:
# return False, e