[DiffusionPipeline.from_pretrained] add warning when passing unused k… (#870)

[DiffusionPipeline.from_pretrained] add warning when passing unused kwargs
This commit is contained in:
Patrick von Platen 2022-10-20 13:30:01 +02:00 committed by GitHub
parent 4a76e5d49b
commit db19a9d9d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 58 additions and 2 deletions

View File

@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin):
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin):
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
if len(unused_kwargs) > 0:
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
init_kwargs = {}

View File

@ -1,9 +1,11 @@
import inspect
import logging
import os
import random
import re
import unittest
from distutils.util import strtobool
from io import StringIO
from pathlib import Path
from typing import Union
@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id):
tr._tw = orig_writer
tr.reportchars = orig_reportchars
config.option.tbstyle = orig_tbstyle
class CaptureLogger:
"""
Args:
Context manager to capture `logging` streams
logger: 'logging` logger object
Returns:
The captured output is available via `self.out`
Example:
```python
>>> from diffusers import logging
>>> from diffusers.testing_utils import CaptureLogger
>>> msg = "Testing 1, 2, 3"
>>> logging.set_verbosity_info()
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
>>> with CaptureLogger(logger) as cl:
... logger.info(msg)
>>> assert cl.out, msg + "\n"
```
"""
def __init__(self, logger):
self.logger = logger
self.io = StringIO()
self.sh = logging.StreamHandler(self.io)
self.out = ""
def __enter__(self):
self.logger.addHandler(self.sh)
return self
def __exit__(self, *exc):
self.logger.removeHandler(self.sh)
self.out = self.io.getvalue()
def __repr__(self):
return f"captured: {self.out}\n"

View File

@ -51,11 +51,12 @@ from diffusers import (
UNet2DConditionModel,
UNet2DModel,
VQModel,
logging,
)
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
from diffusers.utils.testing_utils import get_tests_dir
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from packaging import version
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase):
# is not downloaded, but all the expected ones
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
def test_warning_unused_kwargs(self):
model_id = "hf-internal-testing/unet-pipeline-dummy"
logger = logging.get_logger("diffusers.pipeline_utils")
with tempfile.TemporaryDirectory() as tmpdirname:
with CaptureLogger(logger) as cap_logger:
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
@property
def dummy_safety_checker(self):
def check(images, *args, **kwargs):