import sqlite3
import time

from flask import jsonify, request

from llm_server.routes.stats import SemaphoreCheckerThread
from . import bp
from ..cache import redis
from ..helpers.client import format_sillytavern_err
from ..helpers.http import validate_json
from ..queue import priority_queue
from ... import opts
from ...database import log_prompt
from ...helpers import safe_list_get, indefinite_article

# Priority assigned to requests whose token has no explicit priority row
# (higher number = lower priority in the queue).
DEFAULT_PRIORITY = 9999


# TODO: clean this up and make the ooba vs hf-textgen more object-oriented


def _resolve_client_ip():
    """Return the real client IP, preferring Cloudflare / proxy headers over the socket peer."""
    cf_ip = request.headers.get('cf-connecting-ip')
    if cf_ip:
        return cf_ip
    forwarded = request.headers.get('x-forwarded-for')
    if forwarded:
        # x-forwarded-for may be a comma-separated chain; the first entry is the originating client.
        return forwarded.split(',')[0]
    return request.remote_addr


def _lookup_priority(token):
    """Look up the queue priority for an API token.

    Returns DEFAULT_PRIORITY when no token is given or no row matches.
    """
    priority = None
    if token:
        conn = sqlite3.connect(opts.database_path)
        try:
            cursor = conn.cursor()
            cursor.execute("SELECT priority FROM token_auth WHERE token = ?", (token,))
            result = cursor.fetchone()
            if result:
                priority = result[0]
        finally:
            # Always release the connection, even if the query raises.
            conn.close()
    if priority is None:
        return DEFAULT_PRIORITY
    print(f'Token {token} was given priority {priority}.')
    return priority


def _ratelimited_response(client_ip, token, request_json_body, parameters):
    """Build the response returned when a client exceeds its simultaneous-request quota.

    The event is logged with status 429, but HTTP 200 is returned so SillyTavern
    displays the error text inline instead of treating it as a transport failure.
    """
    backend_response = format_sillytavern_err(
        f'Ratelimited: you are only allowed to have {opts.ip_in_queue_max} simultaneous requests at a time. Please complete your other requests before sending another.',
        'error')
    response_json_body = {
        'results': [
            {
                'text': backend_response,
            }
        ],
    }
    log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), 429, is_error=True)
    return jsonify({
        **response_json_body
    }), 200


@bp.route('/generate', methods=['POST'])
def generate():
    """Accept a generation request, queue it by token priority, relay it to the
    configured backend (oobabooga or hf-textgen), and log the outcome.

    Always returns HTTP 200 with a SillyTavern-compatible JSON body (errors are
    embedded in the result text), except for malformed request JSON (400).
    """
    start_time = time.time()
    request_valid_json, request_json_body = validate_json(request.data)
    if not request_valid_json:
        return jsonify({'code': 400, 'error': 'Invalid JSON'}), 400

    client_ip = _resolve_client_ip()
    SemaphoreCheckerThread.recent_prompters[client_ip] = time.time()

    # Backend parameters are everything in the request except the prompt itself.
    parameters = request_json_body.copy()
    del parameters['prompt']

    token = request.headers.get('X-Api-Key')
    priority = _lookup_priority(token)

    # Number of requests this IP already has queued or in flight.
    queued_ip_count = redis.get_dict('queued_ip_count').get(client_ip, 0) + redis.get_dict('processing_ips').get(client_ip, 0)
    if queued_ip_count < opts.ip_in_queue_max or priority == 0:
        event = priority_queue.put((request_json_body, client_ip, token, parameters), priority)
    else:
        event = None
    if not event:
        return _ratelimited_response(client_ip, token, request_json_body, parameters)

    event.wait()
    success, response, error_msg = event.data
    elapsed_time = time.time() - start_time

    # `response` may be None or a non-response object on failure; only an
    # AttributeError is expected here, so don't swallow anything broader.
    try:
        response_status_code = response.status_code
    except AttributeError:
        response_status_code = 0

    # TODO: why is this if block sitting here
    if (not success or not response) and opts.mode == 'oobabooga':
        # Ooba doesn't return any error messages of its own, so synthesize one.
        backend_response = format_sillytavern_err(f'Failed to reach the backend ({opts.mode}): {error_msg}', 'error')
        response_json_body = {
            'results': [
                {
                    'text': backend_response,
                }
            ],
        }
        # Fix: previously the raw response object was passed where an int HTTP
        # status code was expected; use the guarded response_status_code.
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, None, parameters, dict(request.headers), response_status_code, is_error=True)
        return jsonify({
            'code': 500,
            'error': 'failed to reach backend',
            **response_json_body
        }), 200

    response_valid_json, response_json_body = validate_json(response)

    if not response_valid_json:
        backend_response = format_sillytavern_err(f'The backend did not return valid JSON.', 'error')
        response_json_body = {
            'results': [
                {
                    'text': backend_response,
                }
            ],
        }
        # Fix: use the guarded status code — `response` may be None here.
        log_prompt(client_ip, token, request_json_body['prompt'], backend_response, elapsed_time, parameters, dict(request.headers), response_status_code, is_error=True)
        return jsonify({
            'code': 500,
            'error': 'the backend did not return valid JSON',
            **response_json_body
        }), 200

    backend_err = False
    if opts.mode == 'oobabooga':
        backend_response = safe_list_get(response_json_body.get('results', []), 0, {}).get('text')
        if not backend_response:
            backend_err = True
            backend_response = format_sillytavern_err(
                f'Backend (oobabooga) returned an empty string. This is usually due to an error on the backend during inference. Please check your parameters and try again.',
                'error')
            # Fix: the original indexed results[0] unconditionally, which raised
            # when 'results' was missing or empty — the very case that put us here.
            results = response_json_body.setdefault('results', [])
            if results:
                results[0]['text'] = backend_response
            else:
                results.append({'text': backend_response})
    elif opts.mode == 'hf-textgen':
        backend_response = response_json_body.get('generated_text', '')
        if response_json_body.get('error'):
            backend_err = True
            error_type = response_json_body.get('error_type')
            # (The original also had a dead "oobabooga" branch here; we are
            # provably in hf-textgen mode at this point.)
            error_type_string = f'returned {indefinite_article(error_type)} {error_type} error'
            backend_response = format_sillytavern_err(
                f'Backend ({opts.mode}) {error_type_string}: {response_json_body.get("error")}',
                f'HTTP CODE {response_status_code}')
        # Both the error and success paths emit the same envelope shape.
        response_json_body = {
            'results': [
                {
                    'text': backend_response
                }
            ]
        }
    else:
        raise Exception(f'Unknown backend mode: {opts.mode}')

    if not backend_err:
        redis.incr('proompts')
    log_prompt(client_ip, token, request_json_body['prompt'], backend_response,
               elapsed_time if not backend_err else None, parameters, dict(request.headers),
               response_status_code,
               response_json_body.get('details', {}).get('generated_tokens'),
               is_error=backend_err)
    return jsonify({
        **response_json_body
    }), 200