from flask import jsonify

from ..llm_backend import LLMBackend
from ...database.database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json


class OobaboogaBackend(LLMBackend):
    """LLM backend adapter for an oobabooga server.

    Response handling is currently disabled: ``handle_response`` raises
    ``NotImplementedError`` until ``default_params`` is implemented.
    """

    # Placeholder -- handle_response refuses to run while this is unimplemented.
    default_params = {}

    @staticmethod
    def _get_status_code(response):
        """Return the response's HTTP status code, 0 if unreadable, or None if there is no response.

        Uses ``is not None`` rather than truthiness: if ``response`` is a
        ``requests.Response`` (presumably -- confirm against callers), it is falsy
        for 4xx/5xx replies, which would hide exactly the codes worth logging.
        """
        if response is None:
            return None
        try:
            # Be extra careful when getting attributes from the response object.
            return response.status_code
        except Exception:
            return 0

    @staticmethod
    def _error_reply(msg, backend_response):
        """Build the ``(flask.Response, 400)`` error reply shared by every failure path.

        :param msg: short error description placed in the ``msg`` field.
        :param backend_response: SillyTavern-formatted error text placed in ``results[0].text``.
        """
        # NOTE(review): the body says code 500 while the HTTP status is 400; both are
        # kept unchanged, presumably for client compatibility -- confirm before changing.
        return jsonify({
            'code': 500,
            'msg': msg,
            'results': [{'text': backend_response}],
        }), 400

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        """Turn a backend reply into a Flask ``(response, status)`` tuple and log the prompt.

        NOTE: this method currently raises ``NotImplementedError`` immediately, so
        everything after the ``raise`` is unreachable. The body is kept correct for
        when the guard is removed.

        :param success: whether the caller considered the backend request successful.
        :param request: the incoming Flask request; only ``request.url`` is read (for logging).
        :param response: the backend's HTTP response object, possibly None.
        :param error_msg: error text reported by the caller, if any.
        :param client_ip: passed through to ``log_prompt``.
        :param token: passed through to ``log_prompt``.
        :param prompt: passed through to ``log_prompt``.
        :param elapsed_time: generation time; logged as None when the backend returned empty output.
        :param parameters: passed through to ``log_prompt``.
        :param headers: passed through to ``log_prompt``.
        :return: 200 with the backend's JSON body on success; otherwise 400 with a
            SillyTavern-formatted error in ``results[0].text``.
        :raises NotImplementedError: always, until ``default_params`` is implemented.
        """
        raise NotImplementedError('need to implement default_params')

        response_status_code = self._get_status_code(response)

        # ===============================================
        # We encountered an error
        if not success or not response or error_msg:
            error_msg = error_msg.strip('.') + '.' if error_msg else 'Unknown error.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, request.url, is_error=True)
            return self._error_reply(error_msg, backend_response)
        # ===============================================

        # Only parse once we know there is a response to parse.
        response_valid_json, response_json_body = validate_json(response)
        if not response_valid_json:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response_status_code, request.url, is_error=True)
            return self._error_reply('the backend did not return valid JSON', backend_response)

        results = response_json_body.get('results') or []
        backend_response = safe_list_get(results, 0, {}).get('text')
        backend_err = not backend_response
        if backend_err:
            # Ooba doesn't return any error messages so we will just tell the client an error occurred
            backend_response = format_sillytavern_err(
                'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                'error')
            # Write the error where the client expects the text, creating the slot if
            # the backend sent no results at all (indexing would raise otherwise).
            if results:
                results[0]['text'] = backend_response
            else:
                response_json_body['results'] = [{'text': backend_response}]
        else:
            redis.incr('proompts')

        log_prompt(client_ip, token, prompt, backend_response, None if backend_err else elapsed_time, parameters, headers, response_status_code, request.url,
                   response_tokens=(response_json_body.get('details') or {}).get('generated_tokens'),
                   is_error=backend_err)
        return jsonify(response_json_body), 200

    def validate_params(self, params_dict: dict):
        """Accept any parameters; this backend performs no validation.

        :return: ``(True, None)`` -- always valid, no error message.
        """
        # No validation required
        return True, None

    def get_parameters(self, parameters):
        """Remove the ``'prompt'`` key from *parameters* in place and return the same dict.

        A missing ``'prompt'`` key is ignored rather than raising ``KeyError``.
        """
        parameters.pop('prompt', None)
        return parameters