79 lines
3.4 KiB
Python
79 lines
3.4 KiB
Python
from flask import jsonify
|
|
|
|
from ..llm_backend import LLMBackend
|
|
from ...database.database import log_prompt
|
|
from ...helpers import safe_list_get
|
|
from llm_server.custom_redis import redis
|
|
from ...routes.helpers.client import format_sillytavern_err
|
|
from ...routes.helpers.http import validate_json
|
|
|
|
|
|
class OobaboogaBackend(LLMBackend):
    """Backend adapter for an oobabooga server.

    NOTE: this backend is currently disabled -- ``handle_response`` unconditionally
    raises ``NotImplementedError`` until ``default_params`` is implemented. The
    response-handling logic after that ``raise`` is unreachable and is kept so the
    backend can be re-enabled later.
    """

    # Default generation parameters for this backend. Intentionally empty for now;
    # see the NotImplementedError at the top of handle_response.
    default_params = {}

    def handle_response(self, success, request, response, error_msg, client_ip, token, prompt, elapsed_time,
                        parameters, headers):
        """Convert a backend generation result into a Flask JSON response and log it.

        :param success: whether the request to the backend succeeded.
        :param request: the incoming Flask request; only ``request.url`` is read (for logging).
        :param response: the backend's HTTP response object, or a falsy value if there is none.
        :param error_msg: error message produced while contacting the backend, if any.
        :param client_ip: forwarded to ``log_prompt``.
        :param token: forwarded to ``log_prompt``.
        :param prompt: forwarded to ``log_prompt``.
        :param elapsed_time: generation time; logged only when no backend error occurred.
        :param parameters: forwarded to ``log_prompt``.
        :param headers: forwarded to ``log_prompt``.
        :return: a ``(response, http_status)`` tuple built with ``jsonify``.
        :raises NotImplementedError: always, until ``default_params`` is implemented.
        """
        raise NotImplementedError('need to implement default_params')

        # ------------------------------------------------------------------
        # Everything below is unreachable until the raise above is removed.
        # ------------------------------------------------------------------

        # Be extra careful when getting attributes from the response object:
        # None means "no response", 0 means "response present but status unreadable".
        if response:
            try:
                response_status_code = response.status_code
            except Exception:
                response_status_code = 0
        else:
            response_status_code = None

        # We encountered an error.
        if not success or not response or error_msg:
            error_msg = error_msg.strip('.') + '.' if error_msg else 'Unknown error.'
            backend_response = format_sillytavern_err(error_msg, 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code,
                       request.url, is_error=True)
            # NOTE(review): the body reports code 500 while the HTTP status is 400; kept as-is,
            # presumably so clients display the message -- confirm.
            return jsonify({
                'code': 500,
                'msg': error_msg,
                'results': [{'text': backend_response}]
            }), 400

        # Only parse the body once we know there is a response to parse.
        response_valid_json, response_json_body = validate_json(response)
        if not response_valid_json:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            # Use the guarded status code; re-reading response.status_code could raise here.
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers,
                       response_status_code, request.url, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 400

        backend_err = False
        results = response_json_body.get('results')
        backend_response = safe_list_get(results or [], 0, {}).get('text')
        if not backend_response:
            # Ooba doesn't return any error messages so we will just tell the client an error occurred
            backend_err = True
            backend_response = format_sillytavern_err(
                'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                'error')
            # The body may have no usable results[0] (missing or empty list); create one
            # rather than raising KeyError/IndexError.
            if results:
                results[0]['text'] = backend_response
            else:
                response_json_body['results'] = [{'text': backend_response}]

        if not backend_err:
            redis.incr('proompts')

        # 'details' may be absent or null; treat both as "no token count".
        details = response_json_body.get('details') or {}
        log_prompt(client_ip, token, prompt, backend_response, None if backend_err else elapsed_time, parameters,
                   headers, response_status_code, request.url,
                   response_tokens=details.get('generated_tokens'), is_error=backend_err)
        return jsonify({
            **response_json_body
        }), 200

    def validate_params(self, params_dict: dict):
        """Validate request parameters for this backend.

        :return: ``(True, None)`` -- no validation is performed for oobabooga.
        """
        # No validation required
        return True, None

    def get_parameters(self, parameters):
        """Return the generation parameters to send to the backend, without the prompt.

        ``parameters`` is modified in place (its ``'prompt'`` key is removed) and the
        same dict is returned.
        """
        # pop() with a default so a request without a prompt doesn't raise KeyError.
        parameters.pop('prompt', None)
        return parameters
|