[DiffusionPipeline.from_pretrained] add warning when passing unused k… (#870)

[DiffusionPipeline.from_pretrained] add warning when passing unused kwargs
This commit is contained in:
Patrick von Platen 2022-10-20 13:30:01 +02:00 committed by GitHub
parent 4a76e5d49b
commit db19a9d9d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 2 deletions

View File

@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
""" """
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False) resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None) proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False) local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None) use_auth_token = kwargs.pop("use_auth_token", None)
@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name_or_path, pretrained_model_name_or_path,
cache_dir=cache_dir, cache_dir=cache_dir,
resume_download=resume_download, resume_download=resume_download,
force_download=force_download,
proxies=proxies, proxies=proxies,
local_files_only=local_files_only, local_files_only=local_files_only,
use_auth_token=use_auth_token, use_auth_token=use_auth_token,
@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
init_kwargs = {} init_kwargs = {}

View File

@ -1,9 +1,11 @@
import inspect import inspect
import logging
import os import os
import random import random
import re import re
import unittest import unittest
from distutils.util import strtobool from distutils.util import strtobool
from io import StringIO
from pathlib import Path from pathlib import Path
from typing import Union from typing import Union
@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id):
tr._tw = orig_writer tr._tw = orig_writer
tr.reportchars = orig_reportchars tr.reportchars = orig_reportchars
config.option.tbstyle = orig_tbstyle config.option.tbstyle = orig_tbstyle
class CaptureLogger:
"""
Args:
Context manager to capture `logging` streams
logger: 'logging` logger object
Returns:
The captured output is available via `self.out`
Example:
```python
>>> from diffusers import logging
>>> from diffusers.testing_utils import CaptureLogger
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg + "\n"
```
"""
def __init__(self, logger):
self.logger = logger
self.io = StringIO()
self.sh = logging.StreamHandler(self.io)
self.out = ""
def __enter__(self):
self.logger.addHandler(self.sh)
return self
def __exit__(self, *exc):
self.logger.removeHandler(self.sh)
self.out = self.io.getvalue()
def __repr__(self):
return f"captured: {self.out}\n"

View File

@ -51,11 +51,12 @@ from diffusers import (
UNet2DConditionModel, UNet2DConditionModel,
UNet2DModel, UNet2DModel,
VQModel, VQModel,
logging,
) )
from diffusers.pipeline_utils import DiffusionPipeline from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import get_tests_dir from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from packaging import version from packaging import version
from PIL import Image from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase):
# is not downloaded, but all the expected ones # is not downloaded, but all the expected ones
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy")) assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
def test_warning_unused_kwargs(self):
model_id = "hf-internal-testing/unet-pipeline-dummy"
logger = logging.get_logger("diffusers.pipeline_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
with CaptureLogger(logger) as cap_logger:
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
@property @property
def dummy_safety_checker(self): def dummy_safety_checker(self):
def check(images, *args, **kwargs): def check(images, *args, **kwargs):