From db19a9d9d7d8fafb372675c16fcdde6674d4e1ab Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 20 Oct 2022 13:30:01 +0200 Subject: [PATCH] =?UTF-8?q?[DiffusionPipeline.from=5Fpretrained]=20add=20w?= =?UTF-8?q?arning=20when=20passing=20unused=20k=E2=80=A6=20(#870)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit [DiffusionPipeline.from_pretrained] add warning when passing unused kwargs --- src/diffusers/pipeline_utils.py | 7 ++++- src/diffusers/utils/testing_utils.py | 41 ++++++++++++++++++++++++++++ tests/test_pipelines.py | 12 +++++++- 3 files changed, 58 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 1c787a41..5d186f35 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin): """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) resume_download = kwargs.pop("resume_download", False) + force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin): pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, + force_download=force_download, proxies=proxies, local_files_only=local_files_only, use_auth_token=use_auth_token, @@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin): expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs) + + if len(unused_kwargs) > 0: + logger.warning(f"Keyword arguments {unused_kwargs} not recognized.") init_kwargs = {} diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 682a7471..8289e01e 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -1,9 +1,11 @@ import inspect +import logging import os import random import re import unittest from distutils.util import strtobool +from io import StringIO from pathlib import Path from typing import Union @@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id): tr._tw = orig_writer tr.reportchars = orig_reportchars config.option.tbstyle = orig_tbstyle + + +class CaptureLogger: + """ + Args: + Context manager to capture `logging` streams + logger: 'logging` logger object + Returns: + The captured output is available via `self.out` + Example: + ```python + >>> from diffusers import logging + >>> from diffusers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 1de2e6f5..3efedd50 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -51,11 +51,12 @@ from diffusers import ( UNet2DConditionModel, UNet2DModel, VQModel, + logging, ) from diffusers.pipeline_utils import DiffusionPipeline from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device -from diffusers.utils.testing_utils import get_tests_dir +from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir from packaging import version from PIL import Image from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer @@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase): # is not downloaded, but all the expected ones assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) + def test_warning_unused_kwargs(self): + model_id = "hf-internal-testing/unet-pipeline-dummy" + logger = logging.get_logger("diffusers.pipeline_utils") + with tempfile.TemporaryDirectory() as tmpdirname: + with CaptureLogger(logger) as cap_logger: + DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True) + + assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n" + @property def dummy_safety_checker(self): def check(images, *args, **kwargs):