[Config] Add optional arguments (#1395)
* Optional Components * uP * finish * finish * finish * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * up * Update src/diffusers/pipeline_utils.py * improve Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
e0e86b7470
commit
cbfed0c256
|
@ -129,10 +129,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
Class attributes:
|
||||
|
||||
- **config_name** ([`str`]) -- name of the config file that will store the class and module names of all
|
||||
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
|
||||
components of the diffusion pipeline.
|
||||
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
|
||||
passed for the pipeline to function (should be overridden by subclasses).
|
||||
"""
|
||||
config_name = "model_index.json"
|
||||
_optional_components = []
|
||||
|
||||
def register_modules(self, **kwargs):
|
||||
# import it here to avoid circular import
|
||||
|
@ -184,12 +187,19 @@ class DiffusionPipeline(ConfigMixin):
|
|||
model_index_dict.pop("_diffusers_version")
|
||||
model_index_dict.pop("_module", None)
|
||||
|
||||
expected_modules, optional_kwargs = self._get_signature_keys(self)
|
||||
|
||||
def is_saveable_module(name, value):
|
||||
if name not in expected_modules:
|
||||
return False
|
||||
if name in self._optional_components and value[0] is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
|
@ -523,26 +533,27 @@ class DiffusionPipeline(ConfigMixin):
|
|||
# some modules can be passed directly to the init
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
# extract them here
|
||||
expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"])
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# define init kwargs
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
# remove `null` components
|
||||
init_dict = {k: v for k, v in init_dict.items() if v[0] is not None}
|
||||
|
||||
if len(unused_kwargs) > 0:
|
||||
logger.warning(f"Keyword arguments {unused_kwargs} not recognized.")
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
# import it here to avoid circular import
|
||||
from diffusers import pipelines
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
|
@ -570,7 +581,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||
f" {expected_class_obj}"
|
||||
)
|
||||
elif passed_class_obj[name] is None:
|
||||
elif passed_class_obj[name] is None and name not in pipeline_class._optional_components:
|
||||
logger.warning(
|
||||
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
|
||||
f" that this might lead to problems when using {pipeline_class} and is not recommended."
|
||||
|
@ -651,11 +662,13 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
# 4. Potentially add passed objects if expected
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_class_obj.keys()):
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj[module]
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys()))
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
@ -664,6 +677,14 @@ class DiffusionPipeline(ConfigMixin):
|
|||
model = pipeline_class(**init_kwargs)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _get_signature_keys(obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default is not True}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default is True})
|
||||
expected_modules = set(required_parameters.keys()) - set(["self"])
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
@property
|
||||
def components(self) -> Dict[str, Any]:
|
||||
r"""
|
||||
|
@ -688,8 +709,10 @@ class DiffusionPipeline(ConfigMixin):
|
|||
Returns:
|
||||
A dictionaly containing all the modules needed to initialize the pipeline.
|
||||
"""
|
||||
components = {k: getattr(self, k) for k in self.config.keys() if not k.startswith("_")}
|
||||
expected_modules = set(inspect.signature(self.__init__).parameters.keys()) - set(["self"])
|
||||
expected_modules, optional_parameters = self._get_signature_keys(self)
|
||||
components = {
|
||||
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
|
||||
}
|
||||
|
||||
if set(components.keys()) != expected_modules:
|
||||
raise ValueError(
|
||||
|
|
|
@ -67,6 +67,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -84,6 +85,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -114,7 +116,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
|
||||
|
@ -124,6 +126,12 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -133,6 +141,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
|
|
|
@ -80,6 +80,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -97,6 +98,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -127,7 +129,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Alt Diffusion license and do not expose unfiltered"
|
||||
|
@ -137,6 +139,12 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -146,6 +154,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
|
|
|
@ -132,6 +132,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -142,6 +143,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||
scheduler: DDIMScheduler,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -159,7 +161,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -169,6 +171,12 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -178,6 +186,7 @@ class CycleDiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
|
|
|
@ -51,6 +51,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -81,6 +82,22 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
@ -91,6 +108,7 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
r"""
|
||||
|
|
|
@ -87,6 +87,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -117,7 +118,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -127,6 +128,12 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
@ -137,6 +144,7 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
|
|
|
@ -100,6 +100,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
logger.info("`OnnxStableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
|
||||
|
@ -131,7 +132,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -141,6 +142,12 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
@ -151,6 +158,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
|
|
|
@ -86,6 +86,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: OnnxRuntimeModel,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -116,7 +117,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -126,6 +127,12 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae_encoder=vae_encoder,
|
||||
vae_decoder=vae_decoder,
|
||||
|
@ -136,6 +143,7 @@ class OnnxStableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
|
||||
|
|
|
@ -66,6 +66,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -83,6 +84,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -113,7 +115,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -123,6 +125,12 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -132,6 +140,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
|
|
|
@ -63,6 +63,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -79,10 +80,11 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warn(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -92,6 +94,12 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
image_encoder=image_encoder,
|
||||
|
@ -100,6 +108,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_xformers_memory_efficient_attention
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
|
|
|
@ -78,6 +78,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
|
||||
def __init__(
|
||||
|
@ -96,6 +97,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -126,7 +128,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -136,6 +138,12 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -145,6 +153,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
|
|
|
@ -150,6 +150,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -160,6 +161,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -191,7 +193,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
new_config["skip_prk_steps"] = True
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -201,6 +203,12 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -210,6 +218,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
|
|
|
@ -91,6 +91,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__
|
||||
def __init__(
|
||||
|
@ -109,6 +110,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -139,7 +141,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -149,6 +151,12 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -158,6 +166,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
|
|||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
|
|
|
@ -56,6 +56,8 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
|||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
"""
|
||||
|
||||
_optional_components = ["safety_checker", "feature_extractor"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
|
@ -72,6 +74,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
|||
],
|
||||
safety_checker: SafeStableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
safety_concept: Optional[str] = (
|
||||
|
@ -107,7 +110,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
|||
new_config["clip_sample"] = False
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
|
@ -117,6 +120,12 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
|||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
if safety_checker is not None and feature_extractor is None:
|
||||
raise ValueError(
|
||||
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
|
||||
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@ -127,6 +136,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
|
|||
feature_extractor=feature_extractor,
|
||||
)
|
||||
self._safety_text_concept = safety_concept
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
@property
|
||||
def safety_concept(self):
|
||||
|
|
|
@ -14,8 +14,10 @@
|
|||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
@ -40,7 +42,6 @@ from diffusers import (
|
|||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
UNet2DModel,
|
||||
VQModel,
|
||||
logging,
|
||||
)
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
|
@ -284,32 +285,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
)
|
||||
return model
|
||||
|
||||
def dummy_cond_unet_inpaint(self, sample_size=32):
|
||||
torch.manual_seed(0)
|
||||
model = UNet2DConditionModel(
|
||||
block_out_channels=(32, 64),
|
||||
layers_per_block=2,
|
||||
sample_size=sample_size,
|
||||
in_channels=9,
|
||||
out_channels=4,
|
||||
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
|
||||
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
|
||||
cross_attention_dim=32,
|
||||
)
|
||||
return model
|
||||
|
||||
def dummy_vq_model(self):
|
||||
torch.manual_seed(0)
|
||||
model = VQModel(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=3,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_vae(self):
|
||||
torch.manual_seed(0)
|
||||
model = AutoencoderKL(
|
||||
|
@ -322,6 +298,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
|
@ -337,6 +314,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
@property
|
||||
def dummy_extractor(self):
|
||||
def extract(*args, **kwargs):
|
||||
class Out:
|
||||
|
@ -383,8 +361,8 @@ class PipelineFastTests(unittest.TestCase):
|
|||
"""Test that components property works correctly"""
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae()
|
||||
bert = self.dummy_text_encoder()
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
image = self.dummy_image().cpu().permute(0, 2, 3, 1)[0]
|
||||
|
@ -399,7 +377,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=None,
|
||||
feature_extractor=self.dummy_extractor(),
|
||||
feature_extractor=self.dummy_extractor,
|
||||
).to(torch_device)
|
||||
img2img = StableDiffusionImg2ImgPipeline(**inpaint.components).to(torch_device)
|
||||
text2img = StableDiffusionPipeline(**inpaint.components).to(torch_device)
|
||||
|
@ -439,7 +417,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
assert image_text2img.shape == (1, 64, 64, 3)
|
||||
|
||||
def test_set_scheduler(self):
|
||||
unet = self.dummy_cond_unet
|
||||
unet = self.dummy_cond_unet()
|
||||
scheduler = PNDMScheduler(skip_prk_steps=True)
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
|
@ -471,7 +449,7 @@ class PipelineFastTests(unittest.TestCase):
|
|||
assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
|
||||
|
||||
def test_set_scheduler_consistency(self):
|
||||
unet = self.dummy_cond_unet
|
||||
unet = self.dummy_cond_unet()
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
vae = self.dummy_vae
|
||||
|
@ -514,6 +492,110 @@ class PipelineFastTests(unittest.TestCase):
|
|||
|
||||
assert dict(ddim_config) == dict(ddim_config_2)
|
||||
|
||||
def test_optional_components(self):
|
||||
unet = self.dummy_cond_unet()
|
||||
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
|
||||
vae = self.dummy_vae
|
||||
bert = self.dummy_text_encoder
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
orig_sd = StableDiffusionPipeline(
|
||||
unet=unet,
|
||||
scheduler=pndm,
|
||||
vae=vae,
|
||||
text_encoder=bert,
|
||||
tokenizer=tokenizer,
|
||||
safety_checker=unet,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
)
|
||||
sd = orig_sd
|
||||
|
||||
assert sd.config.requires_safety_checker is True
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd.save_pretrained(tmpdirname)
|
||||
|
||||
# Test that passing None works
|
||||
sd = StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname, feature_extractor=None, safety_checker=None, requires_safety_checker=False
|
||||
)
|
||||
|
||||
assert sd.config.requires_safety_checker is False
|
||||
assert sd.config.safety_checker == (None, None)
|
||||
assert sd.config.feature_extractor == (None, None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd.save_pretrained(tmpdirname)
|
||||
|
||||
# Test that loading previous None works
|
||||
sd = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
assert sd.config.requires_safety_checker is False
|
||||
assert sd.config.safety_checker == (None, None)
|
||||
assert sd.config.feature_extractor == (None, None)
|
||||
|
||||
orig_sd.save_pretrained(tmpdirname)
|
||||
|
||||
# Test that loading without any directory works
|
||||
shutil.rmtree(os.path.join(tmpdirname, "safety_checker"))
|
||||
with open(os.path.join(tmpdirname, sd.config_name)) as f:
|
||||
config = json.load(f)
|
||||
config["safety_checker"] = [None, None]
|
||||
with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
sd = StableDiffusionPipeline.from_pretrained(tmpdirname, requires_safety_checker=False)
|
||||
sd.save_pretrained(tmpdirname)
|
||||
sd = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
assert sd.config.requires_safety_checker is False
|
||||
assert sd.config.safety_checker == (None, None)
|
||||
assert sd.config.feature_extractor == (None, None)
|
||||
|
||||
# Test that loading from deleted model index works
|
||||
with open(os.path.join(tmpdirname, sd.config_name)) as f:
|
||||
config = json.load(f)
|
||||
del config["safety_checker"]
|
||||
del config["feature_extractor"]
|
||||
with open(os.path.join(tmpdirname, sd.config_name), "w") as f:
|
||||
json.dump(config, f)
|
||||
|
||||
sd = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
assert sd.config.requires_safety_checker is False
|
||||
assert sd.config.safety_checker == (None, None)
|
||||
assert sd.config.feature_extractor == (None, None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd.save_pretrained(tmpdirname)
|
||||
|
||||
# Test that partially loading works
|
||||
sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)
|
||||
|
||||
assert sd.config.requires_safety_checker is False
|
||||
assert sd.config.safety_checker == (None, None)
|
||||
assert sd.config.feature_extractor != (None, None)
|
||||
|
||||
# Test that partially loading works
|
||||
sd = StableDiffusionPipeline.from_pretrained(
|
||||
tmpdirname,
|
||||
feature_extractor=self.dummy_extractor,
|
||||
safety_checker=unet,
|
||||
requires_safety_checker=[True, True],
|
||||
)
|
||||
|
||||
assert sd.config.requires_safety_checker == [True, True]
|
||||
assert sd.config.safety_checker != (None, None)
|
||||
assert sd.config.feature_extractor != (None, None)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd.save_pretrained(tmpdirname)
|
||||
sd = StableDiffusionPipeline.from_pretrained(tmpdirname, feature_extractor=self.dummy_extractor)
|
||||
|
||||
assert sd.config.requires_safety_checker == [True, True]
|
||||
assert sd.config.safety_checker != (None, None)
|
||||
assert sd.config.feature_extractor != (None, None)
|
||||
|
||||
|
||||
@slow
|
||||
class PipelineSlowTests(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue