import sqlite3
import time

from flask import jsonify, request

# NOTE: wait_in_queue_elapsed and wait_in_queue_elapsed_lock are used below but were never
# imported. Importing them from the stats module alongside SemaphoreCheckerThread is an
# assumption about where they are defined.
from llm_server.routes.stats import SemaphoreCheckerThread, wait_in_queue_elapsed, wait_in_queue_elapsed_lock
from . import bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..queue import priority_queue
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get

DEFAULT_PRIORITY = 9999


@bp.route('/generate', methods=['POST'])
def generate():
    start_time = time.time()
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

    # Prefer the Cloudflare header, then X-Forwarded-For, then the raw socket address.
    if request.headers.get('cf-connecting-ip'):
        client_ip = request.headers.get('cf-connecting-ip')
    elif request.headers.get('x-forwarded-for'):
        client_ip = request.headers.get('x-forwarded-for')
    else:
        client_ip = request.remote_addr

    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

    parameters = request_json_body.copy()
    del parameters['prompt']

    # Look up the token's queue priority. Unknown or missing tokens get DEFAULT_PRIORITY.
    token = request.headers.get('X-Api-Key')
    priority = None
    if token:
        conn = sqlite3.connect(opts.database_path)
        cursor = conn.cursor()
        cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
        result = cursor.fetchone()
        if result:
            priority = result[0]
        conn.close()
    if priority is None:
        priority = DEFAULT_PRIORITY
    else:
        print(f'Token {token} was given priority {priority}.')

    # success, response, error_msg = generator(request_json_body)
    event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
    event.wait()
    success, response, error_msg = event.data

    # Add the elapsed time to a global list.
    end_time = time.time()
    elapsed_time = end_time - start_time
    # print('elapsed:', elapsed_time)
    with wait_in_queue_elapsed_lock:
        wait_in_queue_elapsed.append((end_time, elapsed_time))

    if not success:
        if opts.mode == 'oobabooga':
            backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
            response_json_body = {
                'results': [
                    {
                        'text': backend_response,
                    }
                ],
            }
        else:
            raise Exception(f'Unsupported backend mode: {opts.mode}')
        # response may be None when the backend was unreachable, so guard the status code lookup.
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers),
                   response.status_code if response is not None else None)
        # Return HTTP 200 so the client (SillyTavern) still renders the error text.
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend',
            **response_json_body
        }), 200

    response_valid_json, response_json_body = validate_json(response)
    if response_valid_json:
        redis.incr('proompts')
        backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
        if not backend_response:
            if opts.mode == 'oobabooga':
                backend_response = format_sillytavern_err(
                    f'Backend (oobabooga) returned an empty string. This can happen when your parameters are incorrect. '
                    f'Make sure your context size is no greater than {opts.context_size}. Furthermore, oobabooga does not '
                    f'support concurrent requests, so all users have to wait in line and the backend server may have '
                    f'glitched for a moment. Please try again.', 'error')
                response_json_body['results'][0]['text'] = backend_response
            else:
                raise Exception(f'Unsupported backend mode: {opts.mode}')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response.status_code)
        return jsonify({
            **response_json_body
        }), 200
    else:
        if opts.mode == 'oobabooga':
            backend_response = format_sillytavern_err('The backend did not return valid JSON.', 'error')
            response_json_body = {
                'results': [
                    {
                        'text': backend_response,
                    }
                ],
            }
        else:
            raise Exception(f'Unsupported backend mode: {opts.mode}')
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response.status_code)
        return jsonify({
            'code': 500,
            'error': 'the backend did not return valid JSON',
            **response_json_body
        }), 200