50 lines
2.2 KiB
Python
50 lines
2.2 KiB
Python
|
import sys
|
||
|
|
||
|
from flask import jsonify
|
||
|
|
||
|
from llm_server.database import log_prompt
|
||
|
from llm_server.helpers import indefinite_article
|
||
|
from llm_server.llm.llm_backend import LLMBackend
|
||
|
from llm_server.routes.helpers.client import format_sillytavern_err
|
||
|
from llm_server.routes.helpers.http import validate_json
|
||
|
|
||
|
|
||
|
class HfTextgenLLMBackend(LLMBackend):
    """Backend adapter for the ``hf-textgen`` backend.

    Converts the backend's HTTP response into the JSON payload sent back to
    the client, logs every exchange through ``log_prompt``, and validates
    generation parameters before they are used.
    """

    def handle_response(self, success, response, error_msg, client_ip, token, prompt, elapsed_time, parameters, headers):
        """Log the backend exchange and build the client-facing reply.

        Args:
            success: Not used by this implementation.
            response: The backend's response object. Its ``status_code``
                attribute is read when present; 0 is recorded otherwise.
            error_msg: Not used by this implementation.
            client_ip: Passed through to ``log_prompt``.
            token: Passed through to ``log_prompt``.
            prompt: Passed through to ``log_prompt``.
            elapsed_time: Passed through to ``log_prompt``; replaced by
                ``None`` when the backend body reports an error.
            parameters: Passed through to ``log_prompt``.
            headers: Passed through to ``log_prompt``.

        Returns:
            A ``(json_response, 200)`` tuple. The HTTP status is 200 in every
            branch; failures are reported inside the ``results`` text (and,
            for invalid JSON, via the ``code``/``msg`` fields).
        """
        response_valid_json, response_json_body = validate_json(response)
        # The response object may lack ``status_code`` (e.g. it is None);
        # record 0 in that case rather than failing.
        response_status_code = getattr(response, 'status_code', 0)

        if not response_valid_json:
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            # Use the guarded status code here: reading ``response.status_code``
            # directly would raise on a response object without that attribute.
            log_prompt(
                client_ip, token, prompt, backend_response, elapsed_time,
                parameters, headers, response_status_code, is_error=True,
            )
            return jsonify({
                'code': 500,
                'msg': 'the backend did not return valid JSON',
                'results': [{'text': backend_response}],
            }), 200

        backend_err = False
        backend_response = response_json_body.get('generated_text', '')

        if response_json_body.get('error'):
            # The backend answered with JSON, but the body describes a failure.
            backend_err = True
            error_type = response_json_body.get('error_type')
            error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
            backend_response = format_sillytavern_err(
                f'Backend (hf-textgen) {error_type_string}: {response_json_body.get("error")}',
                f'HTTP CODE {response_status_code}'
            )

        # ``details`` may be absent or explicitly null; treat both as empty.
        details = response_json_body.get('details') or {}
        log_prompt(
            client_ip, token, prompt, backend_response,
            None if backend_err else elapsed_time,
            parameters, headers, response_status_code,
            details.get('generated_tokens'),
            is_error=backend_err,
        )
        return jsonify({
            'results': [{'text': backend_response}],
        }), 200

    def validate_params(self, params_dict: dict):
        """Check the generation parameters.

        Args:
            params_dict: Parameters to check; only ``typical_p`` is examined.

        Returns:
            ``(True, None)`` when the parameters are acceptable, otherwise
            ``(False, message)`` describing the problem.
        """
        typical_p = params_dict.get('typical_p', 0)
        if typical_p is None:
            # An explicit null is treated like an omitted value instead of
            # raising TypeError on the comparison below.
            typical_p = 0
        # NOTE(review): values above 0.998 are rejected although the message
        # says 0.999; threshold kept as-is to preserve existing behavior.
        if typical_p > 0.998:
            return False, '`typical_p` must be less than 0.999'
        return True, None
|