Give more customizable options for safety checker (#815)

* Give more customizable options for safety checker

* Apply suggestions from code review

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

* Finish

* make style

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Patrick von Platen 2022-10-13 15:52:26 +02:00 committed by GitHub
parent 26c7df5d82
commit e713346ad1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 93 additions and 27 deletions

View File

@ -113,23 +113,26 @@ class DiffusionPipeline(ConfigMixin):
for name, module in kwargs.items(): for name, module in kwargs.items():
# retrieve library # retrieve library
library = module.__module__.split(".")[0] if module is None:
register_dict = {name: (None, None)}
else:
library = module.__module__.split(".")[0]
# check if the module is a pipeline module # check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2] pipeline_dir = module.__module__.split(".")[-2]
path = module.__module__.split(".") path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module. # if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline # Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name. # folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module: if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir library = pipeline_dir
# retrieve class_name # retrieve class_name
class_name = module.__class__.__name__ class_name = module.__class__.__name__
register_dict = {name: (library, class_name)} register_dict = {name: (library, class_name)}
# save model index config # save model index config
self.register_to_config(**register_dict) self.register_to_config(**register_dict)
@ -429,6 +432,7 @@ class DiffusionPipeline(ConfigMixin):
is_pipeline_module = hasattr(pipelines, library_name) is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None loaded_sub_model = None
sub_model_should_be_defined = True
# if the model is in a pipeline module, then we load it from the pipeline # if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj: if name in passed_class_obj:
@ -449,6 +453,12 @@ class DiffusionPipeline(ConfigMixin):
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}" f" {expected_class_obj}"
) )
elif passed_class_obj[name] is None:
logger.warn(
f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
f" that this might lead to problems when using {pipeline_class} and is not recommended."
)
sub_model_should_be_defined = False
else: else:
logger.warn( logger.warn(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@ -469,7 +479,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = LOADABLE_CLASSES[library_name] importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
if loaded_sub_model is None: if loaded_sub_model is None and sub_model_should_be_defined:
load_method_name = None load_method_name = None
for class_name, class_candidate in class_candidates.items(): for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate): if issubclass(class_obj, class_candidate):

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Union from typing import List, Optional, Union
import numpy as np import numpy as np
@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`) nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content. (nsfw) content, or `None` if safety checking could not be performed.
""" """
images: Union[List[PIL.Image.Image], np.ndarray] images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool] nsfw_content_detected: Optional[List[bool]]
if is_transformers_available() and is_torch_available(): if is_transformers_available() and is_torch_available():

View File

@ -71,6 +71,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
@ -335,10 +345,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = image.cpu().permute(0, 2, 3, 1).float().numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) if self.safety_checker is not None:
image, has_nsfw_concept = self.safety_checker( safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) self.device
) )
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)

View File

@ -83,6 +83,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
@ -359,10 +369,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) if self.safety_checker is not None:
image, has_nsfw_concept = self.safety_checker( safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) self.device
) )
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
)
else:
has_nsfw_concept = None
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)

View File

@ -98,6 +98,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
new_config["steps_offset"] = 1 new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config) scheduler._internal_dict = FrozenDict(new_config)
if safety_checker is None:
logger.warn(
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
self.register_modules( self.register_modules(
vae=vae, vae=vae,
text_encoder=text_encoder, text_encoder=text_encoder,
@ -382,8 +392,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
image = (image / 2 + 0.5).clamp(0, 1) image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy() image = image.cpu().permute(0, 2, 3, 1).numpy()
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) if self.safety_checker is not None:
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
self.device
)
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
else:
has_nsfw_concept = None
if output_type == "pil": if output_type == "pil":
image = self.numpy_to_pil(image) image = self.numpy_to_pil(image)

View File

@ -498,6 +498,17 @@ class PipelineFastTests(unittest.TestCase):
assert isinstance(pipe, StableDiffusionPipeline) assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler) assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
def test_stable_diffusion_no_safety_checker(self):
pipe = StableDiffusionPipeline.from_pretrained(
"hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
)
assert isinstance(pipe, StableDiffusionPipeline)
assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
def test_stable_diffusion_k_lms(self): def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet unet = self.dummy_cond_unet