from flask import jsonify, request

from llm_server.routes.stats import concurrent_semaphore, proompts
from . import bp
from ..helpers.client import format_sillytavern_err
from ..helpers.http import cache_control, validate_json
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get


def generator(request_json_body):
    # Dispatch the request to whichever backend is configured in opts.mode.
    if opts.mode == 'oobabooga':
        from ...llm.oobabooga.generate import generate
        return generate(request_json_body)
    elif opts.mode == 'hf-textgen':
        from ...llm.hf_textgen.generate import generate
        return generate(request_json_body)
    else:
        raise Exception(f'Unknown backend mode: {opts.mode}')


@bp.route('/generate', methods=['POST'])
@cache_control(-1)
def generate():
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400
    if 'prompt' not in request_json_body:
        # Reject early so the `del` and log_prompt() calls below can't KeyError.
        return jsonify({'code': 400, 'error': 'Missing prompt'}), 400

    # The semaphore caps how many generations may run concurrently.
    with concurrent_semaphore:
        # Resolve the real client IP: prefer Cloudflare's header, then the
        # first hop in the proxy header, then fall back to the socket address.
        if request.headers.get('cf-connecting-ip'):
            client_ip = request.headers['cf-connecting-ip']
        elif request.headers.get('x-forwarded-for'):
            client_ip = request.headers['x-forwarded-for'].split(',')[0].strip()
        else:
            client_ip = request.remote_addr
        token = request.headers.get('X-Api-Key')

        # Everything in the body except the prompt is a generation parameter.
        parameters = request_json_body.copy()
        del parameters['prompt']

        success, response, error_msg = generator(request_json_body)
        if not success:
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
                response_json_body = {
                    'results': [
                        {
                            'text': backend_response,
                        }
                    ],
                }
            else:
                raise Exception(f'Unsupported mode: {opts.mode}')
            # `response` may be None when the backend was unreachable.
            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code if response is not None else None)
            return jsonify({
                'code': 500,
                'error': 'failed to reach backend',
                **response_json_body
            }), 200

        response_valid_json, response_json_body = validate_json(response)
        if response_valid_json:
            proompts.increment()
            backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
            if not backend_response:
                # An empty result usually means bad parameters (e.g. a context
                # size larger than the backend allows), so return a readable
                # error to the client instead of an empty string.
                if opts.mode == 'oobabooga':
                    backend_response = format_sillytavern_err(f'Backend ({opts.mode}) returned an empty string. This can happen when your parameters are incorrect. Make sure your context size is no greater than {opts.context_size}.', 'error')
                    response_json_body['results'][0]['text'] = backend_response
                else:
                    raise Exception(f'Unsupported mode: {opts.mode}')
            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
            return jsonify({
                **response_json_body
            }), 200
        else:
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
                response_json_body = {
                    'results': [
                        {
                            'text': backend_response,
                        }
                    ],
                }
            else:
                raise Exception(f'Unsupported mode: {opts.mode}')
            log_prompt(opts.database_path, client_ip, token, request_json_body['prompt'], backend_response, parameters, dict(request.headers), response.status_code)
            return jsonify({
                'code': 500,
                'error': 'the backend did not return valid JSON',
                **response_json_body
            }), 200

# @openai_bp.route('/chat/completions', methods=['POST'])
# def generate_openai():
#     print(request.data)
#     return '', 200