local-llm-server/llm_server/llm/oobabooga/ooba_backend.py

from flask import jsonify

from ..llm_backend import LLMBackend
from ...database import log_prompt
from ...helpers import safe_list_get
from ...routes.cache import redis
from ...routes.helpers.client import format_sillytavern_err
from ...routes.helpers.http import validate_json


class OobaboogaLLMBackend(LLMBackend):
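    """LLMBackend implementation for an oobabooga (text-generation-webui) backend."""
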
    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
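        """Validate the backend's response, log the prompt, and build the JSON reply.

        Always returns HTTP 200; failures are reported inside the response body so
        SillyTavern-style clients can display the error text.
        """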
        backend_err = False
        response_valid_json, response_json_body = validate_json(response)
        try:
            # Be extra careful when getting attributes from the response object.
            response_status_code = response.status_code
        except Exception:
            response_status_code = 0
        # ===============================================

        # We encountered an error reaching the backend.
        if not success or not response:
            backend_response = format_sillytavern_err(f'Failed to reach the backend (oobabooga): {error_msg}', 'error')
            log_prompt(client_ip, token, prompt, backend_response, None, parameters, headers, response_status_code, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'failed to reach backend',
                'results': [{'text': backend_response}]
            }), 200
        # ===============================================

        if response_valid_json:
            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
            if not backend_response:
                # Ooba doesn't return any error messages, so we just tell the client an error occurred.
                backend_err = True
                backend_response = format_sillytavern_err(
                    'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                    'error')
                # Rebuild the results array so the client always receives a text field,
                # even when the backend returned an empty `results` list.
                response_json_body['results'] = [{'text': backend_response}]

            if not backend_err:
                redis.incr('proompts')

            log_prompt(client_ip, token, prompt, backend_response, elapsed_time if not backend_err else None, parameters, headers, response_status_code, response_json_body.get('details', {}).get('generated_tokens'), is_error=backend_err)
            return jsonify({
                **response_json_body
            }), 200
        else:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            log_prompt(client_ip, token, prompt, backend_response, elapsed_time, parameters, headers, response_status_code, is_error=True)
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}]
            }), 200

    def validate_params(self, params_dict: dict):
        # No validation required.
        return True, None

    # def get_model_info(self) -> Tuple[dict | bool, Exception | None]:
    #     try:
    #         backend_response = requests.get(f'{opts.backend_url}/api/v1/model', timeout=3, verify=opts.verify_ssl)
    #         r_json = backend_response.json()
    #         return r_json['result'], None
    #     except Exception as e:
    #         return False, e

    def get_parameters(self, parameters):
        # Strip the prompt so only generation parameters remain; pop() avoids a
        # KeyError if 'prompt' is absent.
        parameters.pop('prompt', None)
        return parameters