[DiffusionPipeline.from_pretrained] add warning when passing unused k… (#870)
[DiffusionPipeline.from_pretrained] add warning when passing unused kwargs
This commit is contained in:
parent
4a76e5d49b
commit
db19a9d9d7
|
@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
resume_download=resume_download,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
|
@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin):
|
|||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
|
||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import unittest
|
||||
from distutils.util import strtobool
|
||||
from io import StringIO
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id):
|
|||
tr._tw = orig_writer
|
||||
tr.reportchars = orig_reportchars
|
||||
config.option.tbstyle = orig_tbstyle
|
||||
|
||||
|
||||
class CaptureLogger:
|
||||
"""
|
||||
Args:
|
||||
Context manager to capture `logging` streams
|
||||
logger: 'logging` logger object
|
||||
Returns:
|
||||
The captured output is available via `self.out`
|
||||
Example:
|
||||
```python
|
||||
>>> from diffusers import logging
|
||||
>>> from diffusers.testing_utils import CaptureLogger
|
||||
|
||||
>>> msg = "Testing 1, 2, 3"
|
||||
>>> logging.set_verbosity_info()
|
||||
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
|
||||
>>> with CaptureLogger(logger) as cl:
|
||||
... logger.info(msg)
|
||||
>>> assert cl.out, msg + "\n"
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, logger):
|
||||
self.logger = logger
|
||||
self.io = StringIO()
|
||||
self.sh = logging.StreamHandler(self.io)
|
||||
self.out = ""
|
||||
|
||||
def __enter__(self):
|
||||
self.logger.addHandler(self.sh)
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.logger.removeHandler(self.sh)
|
||||
self.out = self.io.getvalue()
|
||||
|
||||
def __repr__(self):
|
||||
return f"captured: {self.out}\n"
|
||||
|
|
|
@ -51,11 +51,12 @@ from diffusers import (
|
|||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
VQModel,
|
||||
logging,
|
||||
)
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import get_tests_dir
|
||||
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
|
||||
from packaging import version
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
# is not downloaded, but all the expected ones
|
||||
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
|
||||
|
||||
def test_warning_unused_kwargs(self):
|
||||
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
||||
logger = logging.get_logger("diffusers.pipeline_utils")
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
with CaptureLogger(logger) as cap_logger:
|
||||
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
|
||||
|
||||
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
||||
|
||||
@property
|
||||
def dummy_safety_checker(self):
|
||||
def check(images, *args, **kwargs):
|
||||
|
|
Loading…
Reference in New Issue