diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 6cd67829..7e58d048 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -291,7 +291,8 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): # block images if any(has_nsfw_concept): for i, is_nsfw in enumerate(has_nsfw_concept): - images[i] = np.asarray(images_uint8_casted[i]) + if is_nsfw: + images[i] = np.asarray(images_uint8_casted[i]) images = images.reshape(num_devices, batch_size, height, width, 3) else: