[DiffusionPipeline.from_pretrained] add warning when passing unused k… (#870)
[DiffusionPipeline.from_pretrained] add warning when passing unused kwargs
This commit is contained in:
parent
4a76e5d49b
commit
db19a9d9d7
|
@ -350,6 +350,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
"""
|
"""
|
||||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||||
resume_download = kwargs.pop("resume_download", False)
|
resume_download = kwargs.pop("resume_download", False)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
proxies = kwargs.pop("proxies", None)
|
proxies = kwargs.pop("proxies", None)
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||||
|
@ -367,6 +368,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
pretrained_model_name_or_path,
|
pretrained_model_name_or_path,
|
||||||
cache_dir=cache_dir,
|
cache_dir=cache_dir,
|
||||||
resume_download=resume_download,
|
resume_download=resume_download,
|
||||||
|
force_download=force_download,
|
||||||
proxies=proxies,
|
proxies=proxies,
|
||||||
local_files_only=local_files_only,
|
local_files_only=local_files_only,
|
||||||
use_auth_token=use_auth_token,
|
use_auth_token=use_auth_token,
|
||||||
|
@ -439,7 +441,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
|
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
|
||||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||||
|
|
||||||
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||||
|
|
||||||
|
if len(unused_kwargs) > 0:
|
||||||
|
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
|
||||||
|
|
||||||
init_kwargs = {}
|
init_kwargs = {}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
import inspect
|
import inspect
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import re
|
import re
|
||||||
import unittest
|
import unittest
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
|
from io import StringIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
@ -284,3 +286,42 @@ def pytest_terminal_summary_main(tr, id):
|
||||||
tr._tw = orig_writer
|
tr._tw = orig_writer
|
||||||
tr.reportchars = orig_reportchars
|
tr.reportchars = orig_reportchars
|
||||||
config.option.tbstyle = orig_tbstyle
|
config.option.tbstyle = orig_tbstyle
|
||||||
|
|
||||||
|
|
||||||
|
class CaptureLogger:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
Context manager to capture `logging` streams
|
||||||
|
logger: 'logging` logger object
|
||||||
|
Returns:
|
||||||
|
The captured output is available via `self.out`
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
>>> from diffusers import logging
|
||||||
|
>>> from diffusers.testing_utils import CaptureLogger
|
||||||
|
|
||||||
|
>>> msg = "Testing 1, 2, 3"
|
||||||
|
>>> logging.set_verbosity_info()
|
||||||
|
>>> logger = logging.get_logger("diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.py")
|
||||||
|
>>> with CaptureLogger(logger) as cl:
|
||||||
|
... logger.info(msg)
|
||||||
|
>>> assert cl.out, msg + "\n"
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger):
|
||||||
|
self.logger = logger
|
||||||
|
self.io = StringIO()
|
||||||
|
self.sh = logging.StreamHandler(self.io)
|
||||||
|
self.out = ""
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.logger.addHandler(self.sh)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, *exc):
|
||||||
|
self.logger.removeHandler(self.sh)
|
||||||
|
self.out = self.io.getvalue()
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return f"captured: {self.out}\n"
|
||||||
|
|
|
@ -51,11 +51,12 @@ from diffusers import (
|
||||||
UNet2DConditionModel,
|
UNet2DConditionModel,
|
||||||
UNet2DModel,
|
UNet2DModel,
|
||||||
VQModel,
|
VQModel,
|
||||||
|
logging,
|
||||||
)
|
)
|
||||||
from diffusers.pipeline_utils import DiffusionPipeline
|
from diffusers.pipeline_utils import DiffusionPipeline
|
||||||
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
||||||
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
|
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, load_image, slow, torch_device
|
||||||
from diffusers.utils.testing_utils import get_tests_dir
|
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||||
|
@ -1473,6 +1474,15 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
# is not downloaded, but all the expected ones
|
# is not downloaded, but all the expected ones
|
||||||
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
|
assert not os.path.isfile(os.path.join(snapshot_dir, "big_array.npy"))
|
||||||
|
|
||||||
|
def test_warning_unused_kwargs(self):
|
||||||
|
model_id = "hf-internal-testing/unet-pipeline-dummy"
|
||||||
|
logger = logging.get_logger("diffusers.pipeline_utils")
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
with CaptureLogger(logger) as cap_logger:
|
||||||
|
DiffusionPipeline.from_pretrained(model_id, not_used=True, cache_dir=tmpdirname, force_download=True)
|
||||||
|
|
||||||
|
assert cap_logger.out == "Keyword arguments {'not_used': True} not recognized.\n"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dummy_safety_checker(self):
|
def dummy_safety_checker(self):
|
||||||
def check(images, *args, **kwargs):
|
def check(images, *args, **kwargs):
|
||||||
|
|
Loading…
Reference in New Issue