From ce31f83d8c44c977dcf8413c99b8a6cb7f189c07 Mon Sep 17 00:00:00 2001 From: Ryan Russell Date: Fri, 23 Sep 2022 08:02:12 -0500 Subject: [PATCH] refactor: pipelines readability improvements (#622) * refactor: pipelines readability improvements Signed-off-by: Ryan Russell * docs: remove todo comment from flax pipeline Signed-off-by: Ryan Russell Signed-off-by: Ryan Russell --- .../pipeline_flax_stable_diffusion.py | 7 +++--- .../pipeline_stable_diffusion.py | 6 ++--- .../pipeline_stable_diffusion_img2img.py | 6 ++--- .../pipeline_stable_diffusion_inpaint.py | 6 ++--- .../stable_diffusion/safety_checker.py | 24 +++++++++---------- 5 files changed, 24 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 675b6126..870a715e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -34,7 +34,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`]. safety_checker ([`FlaxStableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -149,7 +149,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0] context = jnp.concatenate([uncond_embeddings, text_embeddings]) - # TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline latents_shape = ( batch_size, self.unet.in_channels, @@ -206,9 +205,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): # image = jnp.asarray(image).transpose(0, 2, 3, 1) # run safety checker # TODO: check when flax safety checker gets merged into main - # safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") + # safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np") # image, has_nsfw_concept = self.safety_checker( - # images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"] + # images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"] # ) has_nsfw_concept = False diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 216a76a5..411f2730 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index e2affac6..46299bf3 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -288,8 +288,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 8d18d2f3..7de7925a 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offsensive or harmful. + Classification module that estimates whether generated images could be considered offensive or harmful. Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. feature_extractor ([`CLIPFeatureExtractor`]): Model that extracts features from generated images to be used as inputs for the `safety_checker`. @@ -328,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): image = image.cpu().permute(0, 2, 3, 1).numpy() # run safety checker - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 09de92ee..3eb8828c 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel): # at the cost of increasing the possibility of filtering benign images adjustment = 0.0 - for concet_idx in range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concet_idx] - concept_threshold = self.special_care_embeds_weights[concet_idx].item() - result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concet_idx] > 0: - result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) adjustment = 0.01 - for concet_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concet_idx] - concept_threshold = self.concept_embeds_weights[concet_idx].item() - result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concet_idx] > 0: - result_img["bad_concepts"].append(concet_idx) + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) result.append(result_img)