diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py
index 94a6c67b..5b781f0e 100644
--- a/src/diffusers/pipeline_utils.py
+++ b/src/diffusers/pipeline_utils.py
@@ -42,6 +42,7 @@ LOADABLE_CLASSES = {
         "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
         "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
         "PreTrainedModel": ["save_pretrained", "from_pretrained"],
+        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
     },
 }
 
@@ -63,9 +64,9 @@ class DiffusionPipeline(ConfigMixin):
             library = module.__module__.split(".")[0]
 
             # check if the module is a pipeline module
-            pipeline_file = module.__module__.split(".")[-1]
             pipeline_dir = module.__module__.split(".")[-2]
-            is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir)
+            path = module.__module__.split(".")
+            is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
 
             # if library is not in LOADABLE_CLASSES, then it is a custom module.
             # Or if it's a pipeline module, then the module is inside the pipeline
diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py
index 5e48f6f5..5306ba82 100644
--- a/src/diffusers/pipelines/stable_diffusion/__init__.py
+++ b/src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -3,4 +3,4 @@ from ...utils import is_transformers_available
 
 
 if is_transformers_available():
-    from .pipeline_stable_diffusion import StableDiffusionPipeline
+    from .pipeline_stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 3b4acd46..baff1db9 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -4,11 +4,12 @@ from typing import List, Optional, Union
 
 import torch
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from .safety_checker import StableDiffusionSafetyChecker
 
 
 class StableDiffusionPipeline(DiffusionPipeline):
@@ -19,10 +20,20 @@ class StableDiffusionPipeline(DiffusionPipeline):
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
-        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )
 
     @torch.no_grad()
     def __call__(
@@ -53,6 +64,7 @@
         self.unet.to(torch_device)
         self.vae.to(torch_device)
         self.text_encoder.to(torch_device)
+        self.safety_checker.to(torch_device)
 
         # get prompt text embeddings
         text_input = self.tokenizer(
@@ -136,7 +148,12 @@
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
+        # run safety checker
+        safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image}
+        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
new file mode 100644
index 00000000..1c5db421
--- /dev/null
+++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py
@@ -0,0 +1,77 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+from ...utils import logging
+
+
+logger = logging.get_logger(__name__)
+
+
+def cosine_distance(image_embeds, text_embeds):
+    normalized_image_embeds = nn.functional.normalize(image_embeds)
+    normalized_text_embeds = nn.functional.normalize(text_embeds)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.T)
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+    config_class = CLIPConfig
+
+    def __init__(self, config: CLIPConfig):
+        super().__init__(config)
+
+        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+        self.register_buffer("concept_embeds_weights", torch.ones(17))
+        self.register_buffer("special_care_embeds_weights", torch.ones(3))
+
+    @torch.no_grad()
+    def forward(self, clip_input, images):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+
+        result = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+            adjustment = 0.05
+
+            for concept_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concept_idx]
+                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
+                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concept_idx] > 0:
+                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
+                    adjustment = 0.01
+
+            for concept_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concept_idx]
+                concept_threshold = self.concept_embeds_weights[concept_idx].item()
+                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concept_idx] > 0:
+                    result_img["bad_concepts"].append(concept_idx)
+
+            result.append(result_img)
+
+        has_nsfw_concepts = [len(result[i]["bad_concepts"]) > 0 for i in range(len(result))]
+
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                images[idx] = np.zeros(images[idx].shape)  # black image
+
+        if any(has_nsfw_concepts):
+            logger.warning(
+                "Potential NSFW content was detected in one or more images. A black image will be returned instead."
+                " Try again with a different prompt and/or seed."
+            )
+
+        return images, has_nsfw_concepts
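
Note on the `pipeline_utils.py` hunk: `DiffusionPipeline` resolves each module registered via `register_modules` against the base classes in `LOADABLE_CLASSES` to pick its save/load methods, so the new `FeatureExtractionMixin` entry is what lets `save_pretrained`/`from_pretrained` round-trip the `feature_extractor`. A minimal sketch of that assumption (the directory is a placeholder):

```python
# CLIPFeatureExtractor inherits from transformers.FeatureExtractionMixin, so the
# new LOADABLE_CLASSES entry maps it onto save_pretrained/from_pretrained when
# the pipeline serializes its registered modules.
from transformers import CLIPFeatureExtractor, FeatureExtractionMixin

extractor = CLIPFeatureExtractor()  # default config, stand-in for a real checkpoint
assert isinstance(extractor, FeatureExtractionMixin)

extractor.save_pretrained("./feature_extractor")  # method named in LOADABLE_CLASSES
restored = CLIPFeatureExtractor.from_pretrained("./feature_extractor")
```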
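With the safety checker and feature extractor registered on the pipeline, `__call__` now returns `nsfw_content_detected` alongside `sample`. A minimal usage sketch, assuming a checkpoint that ships `safety_checker/` and `feature_extractor/` sub-folders; the model id and prompt are placeholders:

```python
from diffusers import StableDiffusionPipeline

# placeholder model id; any checkpoint saved with the updated pipeline layout works
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

output = pipe(["a photograph of an astronaut riding a horse"])
image = output["sample"][0]

# new: one boolean per generated image, True where the checker blanked the output
if output["nsfw_content_detected"][0]:
    print("Image was flagged and replaced with a black image; retry with another prompt or seed.")
else:
    image.save("astronaut.png")
```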
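For intuition about `safety_checker.py`: an image is flagged when `cosine_similarity(image_embed, concept_embed) - concept_threshold + adjustment > 0` for any concept, with `adjustment` starting at 0.05 and switching to 0.01 once a "special care" concept fires. A toy, vectorized re-statement with invented embeddings and thresholds, not the shipped checker:

```python
import torch
import torch.nn as nn


def cosine_similarity_matrix(a, b):
    # same computation as the diff's cosine_distance() (despite the name, it
    # returns cosine similarities, so higher means more similar)
    return torch.mm(nn.functional.normalize(a), nn.functional.normalize(b).T)


image_embeds = torch.randn(2, 4)                    # 2 images, toy 4-dim embeddings
concept_embeds = torch.randn(3, 4)                  # 3 hypothetical concepts
concept_thresholds = torch.tensor([0.2, 0.5, 0.8])  # stand-in for concept_embeds_weights

adjustment = 0.05  # the diff switches this to 0.01 after a special-care hit
scores = cosine_similarity_matrix(image_embeds, concept_embeds) - concept_thresholds + adjustment

# flag an image if any concept score is positive, i.e. the same decision as
# len(result_img["bad_concepts"]) > 0 in the per-image loop above
has_nsfw = (scores > 0).any(dim=1)
print(scores, has_nsfw)
```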