refactor: pipelines readability improvements (#622)
* refactor: pipelines readability improvements Signed-off-by: Ryan Russell <git@ryanrussell.org> * docs: remove todo comment from flax pipeline Signed-off-by: Ryan Russell <git@ryanrussell.org> Signed-off-by: Ryan Russell <git@ryanrussell.org>
This commit is contained in:
parent
b00382e2a7
commit
ce31f83d8c
|
@ -34,7 +34,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`FlaxDDIMScheduler`], [`FlaxLMSDiscreteScheduler`], or [`FlaxPNDMScheduler`].
|
||||
safety_checker ([`FlaxStableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
|
@ -149,7 +149,6 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
uncond_embeddings = self.text_encoder(uncond_input.input_ids, params=params["text_encoder"])[0]
|
||||
context = jnp.concatenate([uncond_embeddings, text_embeddings])
|
||||
|
||||
# TODO: check it because the shape is different from Pytorhc StableDiffusionPipeline
|
||||
latents_shape = (
|
||||
batch_size,
|
||||
self.unet.in_channels,
|
||||
|
@ -206,9 +205,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
|||
# image = jnp.asarray(image).transpose(0, 2, 3, 1)
|
||||
# run safety checker
|
||||
# TODO: check when flax safety checker gets merged into main
|
||||
# safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
# safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="np")
|
||||
# image, has_nsfw_concept = self.safety_checker(
|
||||
# images=image, clip_input=safety_cheker_input.pixel_values, params=params["safety_params"]
|
||||
# images=image, clip_input=safety_checker_input.pixel_values, params=params["safety_params"]
|
||||
# )
|
||||
has_nsfw_concept = False
|
||||
|
||||
|
|
|
@ -36,7 +36,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
|
@ -278,8 +278,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
|
|
@ -48,7 +48,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
|
@ -288,8 +288,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
|
|||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
|
|
@ -66,7 +66,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
|
||||
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
|
||||
safety_checker ([`StableDiffusionSafetyChecker`]):
|
||||
Classification module that estimates whether generated images could be considered offsensive or harmful.
|
||||
Classification module that estimates whether generated images could be considered offensive or harmful.
|
||||
Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
|
||||
feature_extractor ([`CLIPFeatureExtractor`]):
|
||||
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
|
||||
|
@ -328,8 +328,8 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
|
|||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
# run safety checker
|
||||
safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
|
||||
safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
|
||||
image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
|
|
@ -48,20 +48,20 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
|
|||
# at the cost of increasing the possibility of filtering benign images
|
||||
adjustment = 0.0
|
||||
|
||||
for concet_idx in range(len(special_cos_dist[0])):
|
||||
concept_cos = special_cos_dist[i][concet_idx]
|
||||
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
|
||||
result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["special_scores"][concet_idx] > 0:
|
||||
result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
|
||||
for concept_idx in range(len(special_cos_dist[0])):
|
||||
concept_cos = special_cos_dist[i][concept_idx]
|
||||
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
|
||||
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["special_scores"][concept_idx] > 0:
|
||||
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
|
||||
adjustment = 0.01
|
||||
|
||||
for concet_idx in range(len(cos_dist[0])):
|
||||
concept_cos = cos_dist[i][concet_idx]
|
||||
concept_threshold = self.concept_embeds_weights[concet_idx].item()
|
||||
result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["concept_scores"][concet_idx] > 0:
|
||||
result_img["bad_concepts"].append(concet_idx)
|
||||
for concept_idx in range(len(cos_dist[0])):
|
||||
concept_cos = cos_dist[i][concept_idx]
|
||||
concept_threshold = self.concept_embeds_weights[concept_idx].item()
|
||||
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
|
||||
if result_img["concept_scores"][concept_idx] > 0:
|
||||
result_img["bad_concepts"].append(concept_idx)
|
||||
|
||||
result.append(result_img)
|
||||
|
||||
|
|
Loading…
Reference in New Issue