Give more customizable options for safety checker (#815)
* Give more customizable options for safety checker
* Apply suggestions from code review
* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
* Finish
* make style
* Apply suggestions from code review
* up

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 26c7df5d82
commit e713346ad1
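In short: the Stable Diffusion pipelines now treat `safety_checker` as an optional component that can be disabled (or replaced) when loading. A minimal usage sketch, not taken from this diff — the checkpoint name, prompt, and step count are only examples (the test added at the end of this commit uses `hf-internal-testing/tiny-stable-diffusion-lms-pipe`):

from diffusers import StableDiffusionPipeline

# Passing safety_checker=None skips safety checking; the pipeline logs the license warning
# added in the __init__ hunks below. "CompVis/stable-diffusion-v1-4" is just an example checkpoint.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", safety_checker=None)
output = pipe("an example prompt", num_inference_steps=20)
print(output.nsfw_content_detected)  # None, because the checker was disabled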
@@ -113,23 +113,26 @@ class DiffusionPipeline(ConfigMixin):
         for name, module in kwargs.items():
             # retrieve library
-            library = module.__module__.split(".")[0]
+            if module is None:
+                register_dict = {name: (None, None)}
+            else:
+                library = module.__module__.split(".")[0]
 
-            # check if the module is a pipeline module
-            pipeline_dir = module.__module__.split(".")[-2]
-            path = module.__module__.split(".")
-            is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
+                # check if the module is a pipeline module
+                pipeline_dir = module.__module__.split(".")[-2]
+                path = module.__module__.split(".")
+                is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
 
-            # if library is not in LOADABLE_CLASSES, then it is a custom module.
-            # Or if it's a pipeline module, then the module is inside the pipeline
-            # folder so we set the library to module name.
-            if library not in LOADABLE_CLASSES or is_pipeline_module:
-                library = pipeline_dir
+                # if library is not in LOADABLE_CLASSES, then it is a custom module.
+                # Or if it's a pipeline module, then the module is inside the pipeline
+                # folder so we set the library to module name.
+                if library not in LOADABLE_CLASSES or is_pipeline_module:
+                    library = pipeline_dir
 
-            # retrieve class_name
-            class_name = module.__class__.__name__
+                # retrieve class_name
+                class_name = module.__class__.__name__
 
-            register_dict = {name: (library, class_name)}
+                register_dict = {name: (library, class_name)}
 
             # save model index config
             self.register_to_config(**register_dict)
@@ -429,6 +432,7 @@ class DiffusionPipeline(ConfigMixin):
             is_pipeline_module = hasattr(pipelines, library_name)
             loaded_sub_model = None
+            sub_model_should_be_defined = True
 
             # if the model is in a pipeline module, then we load it from the pipeline
             if name in passed_class_obj:
@@ -449,6 +453,12 @@ class DiffusionPipeline(ConfigMixin):
                             f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
                             f" {expected_class_obj}"
                         )
+                elif passed_class_obj[name] is None:
+                    logger.warn(
+                        f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note"
+                        f" that this might lead to problems when using {pipeline_class} and is not recommended."
+                    )
+                    sub_model_should_be_defined = False
                 else:
                     logger.warn(
                         f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
@@ -469,7 +479,7 @@ class DiffusionPipeline(ConfigMixin):
                 importable_classes = LOADABLE_CLASSES[library_name]
                 class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
 
-            if loaded_sub_model is None:
+            if loaded_sub_model is None and sub_model_should_be_defined:
                 load_method_name = None
                 for class_name, class_candidate in class_candidates.items():
                     if issubclass(class_obj, class_candidate):
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import List, Union
+from typing import List, Optional, Union
 
 import numpy as np
@@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput):
             num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
         nsfw_content_detected (`List[bool]`)
             List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
-            (nsfw) content.
+            (nsfw) content, or `None` if safety checking could not be performed.
     """
 
     images: Union[List[PIL.Image.Image], np.ndarray]
-    nsfw_content_detected: List[bool]
+    nsfw_content_detected: Optional[List[bool]]
 
 
 if is_transformers_available() and is_torch_available():
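Since `nsfw_content_detected` is now `Optional[List[bool]]`, callers should guard against `None` before iterating or indexing. A short sketch, assuming `pipe` is an already loaded `StableDiffusionPipeline` and the prompt is only illustrative:

output = pipe("a photo of an astronaut riding a horse")
if output.nsfw_content_detected is None:
    # safety checking was skipped, e.g. the pipeline was loaded with safety_checker=None
    print("no safety information available")
else:
    for i, flagged in enumerate(output.nsfw_content_detected):
        print(f"image {i}: nsfw flagged = {flagged}")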
@@ -71,6 +71,16 @@ class StableDiffusionPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        if safety_checker is None:
+            logger.warn(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -335,10 +345,15 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
         image = image.cpu().permute(0, 2, 3, 1).float().numpy()
 
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-        )
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+            )
+        else:
+            has_nsfw_concept = None
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
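The `clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)` cast kept in the branch above matters for half-precision runs: the feature extractor produces float32 pixel values, so they are converted to the pipeline's working dtype before reaching the safety checker. A hedged sketch of such a run, not part of this diff — it assumes a CUDA device and the fp16 weights of the example checkpoint:

import torch
from diffusers import StableDiffusionPipeline

# The safety checker stays enabled here; any flagged images are reported via nsfw_content_detected.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
).to("cuda")
result = pipe("a photo of an astronaut riding a horse")
print(result.nsfw_content_detected)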
@@ -83,6 +83,16 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        if safety_checker is None:
+            logger.warn(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -359,10 +369,15 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(
-            images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
-        )
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
+            )
+        else:
+            has_nsfw_concept = None
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
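The same `safety_checker=None` option applies to the img2img pipeline above (and to the inpaint pipeline below). A hedged usage sketch — the checkpoint name, input file, and parameters are illustrative, and at the time of this commit the img2img call takes `init_image`:

from PIL import Image
from diffusers import StableDiffusionImg2ImgPipeline

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", safety_checker=None
)
# "sketch.png" stands in for any RGB image to transform; resize to the usual 768x512 working size.
init_image = Image.open("sketch.png").convert("RGB").resize((768, 512))
result = pipe(prompt="a fantasy landscape", init_image=init_image, strength=0.75)
print(result.nsfw_content_detected)  # None, since the checker is disabled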
|
@ -98,6 +98,16 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
new_config["steps_offset"] = 1
|
||||
scheduler._internal_dict = FrozenDict(new_config)
|
||||
|
||||
if safety_checker is None:
|
||||
logger.warn(
|
||||
f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
|
||||
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
|
||||
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
|
||||
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
|
||||
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
|
||||
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
|
||||
)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
|
@@ -382,8 +392,13 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        if self.safety_checker is not None:
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
+                self.device
+            )
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+        else:
+            has_nsfw_concept = None
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
@@ -498,6 +498,17 @@ class PipelineFastTests(unittest.TestCase):
         assert isinstance(pipe, StableDiffusionPipeline)
         assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
 
+    def test_stable_diffusion_no_safety_checker(self):
+        pipe = StableDiffusionPipeline.from_pretrained(
+            "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None
+        )
+        assert isinstance(pipe, StableDiffusionPipeline)
+        assert isinstance(pipe.scheduler, LMSDiscreteScheduler)
+        assert pipe.safety_checker is None
+
+        image = pipe("example prompt", num_inference_steps=2).images[0]
+        assert image is not None
+
     def test_stable_diffusion_k_lms(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet