Give more customizable options for safety checker (#815)
* Give more customizable options for safety checker
* Apply suggestions from code review
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* Finish
* make style
* Apply suggestions from code review
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* up
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent
26c7df5d82
commit
e713346ad1
|
@ -113,23 +113,26 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
for name, module in kwargs.items():
|
for name, module in kwargs.items():
|
||||||
# retrieve library
|
# retrieve library
|
||||||
library = module.__module__.split(".")[0]
|
if module is None:
|
||||||
|
register_dict = {name: (None, None)}
|
||||||
|
else:
|
||||||
|
library = module.__module__.split(".")[0]
|
||||||
|
|
||||||
# check if the module is a pipeline module
|
# check if the module is a pipeline module
|
||||||
pipeline_dir = module.__module__.split(".")[-2]
|
pipeline_dir = module.__module__.split(".")[-2]
|
||||||
path = module.__module__.split(".")
|
path = module.__module__.split(".")
|
||||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||||
|
|
||||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||||
# folder so we set the library to module name.
|
# folder so we set the library to module name.
|
||||||
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
||||||
library = pipeline_dir
|
library = pipeline_dir
|
||||||
|
|
||||||
# retrieve class_name
|
# retrieve class_name
|
||||||
class_name = module.__class__.__name__
|
class_name = module.__class__.__name__
|
||||||
|
|
||||||
register_dict = {name: (library, class_name)}
|
register_dict = {name: (library, class_name)}
|
||||||
|
|
||||||
# save model index config
|
# save model index config
|
||||||
self.register_to_config(**register_dict)
|
self.register_to_config(**register_dict)
|
||||||
|
@ -429,6 +432,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
loaded_sub_model = None
|
loaded_sub_model = None
|
||||||
|
sub_model_should_be_defined = True
|
||||||
|
|
||||||
# if the model is in a pipeline module, then we load it from the pipeline
|
# if the model is in a pipeline module, then we load it from the pipeline
|
||||||
if name in passed_class_obj:
|
if name in passed_class_obj:
|
||||||
|
@ -449,6 +453,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
|
||||||
f" {expected_class_obj}"
|
f" {expected_class_obj}"
|
||||||
)
|
)
|
||||||
|
elif passed_class_obj[name] is None:
|
||||||
|
logger.warn(
|
||||||
|
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
|
||||||
|
f" that this might lead to problems when using {pipeline_class} and is not recommended."
|
||||||
|
)
|
||||||
|
sub_model_should_be_defined = False
|
||||||
else:
|
else:
|
||||||
logger.warn(
|
logger.warn(
|
||||||
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
|
||||||
|
@ -469,7 +479,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
importable_classes = LOADABLE_CLASSES[library_name]
|
importable_classes = LOADABLE_CLASSES[library_name]
|
||||||
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
|
||||||
|
|
||||||
if loaded_sub_model is None:
|
if loaded_sub_model is None and sub_model_should_be_defined:
|
||||||
load_method_name = None
|
load_method_name = None
|
||||||
for class_name, class_candidate in class_candidates.items():
|
for class_name, class_candidate in class_candidates.items():
|
||||||
if issubclass(class_obj, class_candidate):
|
if issubclass(class_obj, class_candidate):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Union
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
|
||||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||||
nsfw_content_detected (`List[bool]`)
|
nsfw_content_detected (`List[bool]`)
|
||||||
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||||
(nsfw) content.
|
(nsfw) content, or `None` if safety checking could not be performed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||||
nsfw_content_detected: List[bool]
|
nsfw_content_detected: Optional[List[bool]]
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available() and is_torch_available():
|
if is_transformers_available() and is_torch_available():
|
||||||
|
|
|
@ -71,6 +71,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||||
new_config["steps_offset"] = 1
|
new_config["steps_offset"] = 1
|
||||||
scheduler._internal_dict = FrozenDict(new_config)
|
scheduler._internal_dict = FrozenDict(new_config)
|
||||||
|
|
||||||
|
if safety_checker is None:
|
||||||
|
logger.warn(
|
||||||
|
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||||
|
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||||
|
)
|
||||||
|
|
||||||
self.register_modules(
|
self.register_modules(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
@ -335,10 +345,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||||
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
||||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||||
|
|
||||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
if self.safety_checker is not None:
|
||||||
image, has_nsfw_concept = self.safety_checker(
|
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
self.device
|
||||||
)
|
)
|
||||||
|
image, has_nsfw_concept = self.safety_checker(
|
||||||
|
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
has_nsfw_concept = None
|
||||||
|
|
||||||
if output_type == "pil":
|
if output_type == "pil":
|
||||||
image = self.numpy_to_pil(image)
|
image = self.numpy_to_pil(image)
|
||||||
|
|
|
@ -83,6 +83,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||||
new_config["steps_offset"] = 1
|
new_config["steps_offset"] = 1
|
||||||
scheduler._internal_dict = FrozenDict(new_config)
|
scheduler._internal_dict = FrozenDict(new_config)
|
||||||
|
|
||||||
|
if safety_checker is None:
|
||||||
|
logger.warn(
|
||||||
|
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||||
|
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||||
|
)
|
||||||
|
|
||||||
self.register_modules(
|
self.register_modules(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
@ -359,10 +369,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||||
|
|
||||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
if self.safety_checker is not None:
|
||||||
image, has_nsfw_concept = self.safety_checker(
|
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||||
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
self.device
|
||||||
)
|
)
|
||||||
|
image, has_nsfw_concept = self.safety_checker(
|
||||||
|
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
has_nsfw_concept = None
|
||||||
|
|
||||||
if output_type == "pil":
|
if output_type == "pil":
|
||||||
image = self.numpy_to_pil(image)
|
image = self.numpy_to_pil(image)
|
||||||
|
|
|
@ -98,6 +98,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||||
new_config["steps_offset"] = 1
|
new_config["steps_offset"] = 1
|
||||||
scheduler._internal_dict = FrozenDict(new_config)
|
scheduler._internal_dict = FrozenDict(new_config)
|
||||||
|
|
||||||
|
if safety_checker is None:
|
||||||
|
logger.warn(
|
||||||
|
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||||
|
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||||
|
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||||
|
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||||
|
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||||
|
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||||
|
)
|
||||||
|
|
||||||
self.register_modules(
|
self.register_modules(
|
||||||
vae=vae,
|
vae=vae,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
@ -382,8 +392,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
||||||
image = (image / 2 + 0.5).clamp(0, 1)
|
image = (image / 2 + 0.5).clamp(0, 1)
|
||||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||||
|
|
||||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
if self.safety_checker is not None:
|
||||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
|
||||||
|
self.device
|
||||||
|
)
|
||||||
|
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||||
|
else:
|
||||||
|
has_nsfw_concept = None
|
||||||
|
|
||||||
if output_type == "pil":
|
if output_type == "pil":
|
||||||
image = self.numpy_to_pil(image)
|
image = self.numpy_to_pil(image)
|
||||||
|
|
|
@ -498,6 +498,17 @@ class PipelineFastTests(unittest.TestCase):
|
||||||
assert isinstance(pipe, StableDiffusionPipeline)
|
assert isinstance(pipe, StableDiffusionPipeline)
|
||||||
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
|
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
|
||||||
|
|
||||||
|
def test_stable_diffusion_no_safety_checker(self):
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
|
||||||
|
)
|
||||||
|
assert isinstance(pipe, StableDiffusionPipeline)
|
||||||
|
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
|
||||||
|
assert pipe.safety_checker is None
|
||||||
|
|
||||||
|
image = pipe("example prompt", num_inference_steps=2).images[0]
|
||||||
|
assert image is not None
|
||||||
|
|
||||||
def test_stable_diffusion_k_lms(self):
|
def test_stable_diffusion_k_lms(self):
|
||||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||||
unet = self.dummy_cond_unet
|
unet = self.dummy_cond_unet
|
||||||
|
|
Loading…
Reference in New Issue