diff --git a/examples/inference/readme.md b/examples/inference/readme.md index e61004e4..02391243 100644 --- a/examples/inference/readme.md +++ b/examples/inference/readme.md @@ -47,4 +47,8 @@ with autocast("cuda"): images[0].save("fantasy_landscape.png") ``` -You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) \ No newline at end of file +You can also run this example on colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/patil-suraj/Notebooks/blob/master/image_2_image_using_diffusers.ipynb) + +## Tweak prompts reusing seeds and latents + +You can generate your own latents to reproduce results, or tweak your prompt on a specific result you liked. [This notebook](stable-diffusion-seeds.ipynb) shows how to do it step by step. You can also run it in Google Colab [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/pcuenca/diffusers-examples/blob/main/notebooks/stable-diffusion-seeds.ipynb). diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 550513b5..f0b353d9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -46,6 +46,7 @@ class StableDiffusionPipeline(DiffusionPipeline): guidance_scale: Optional[float] = 7.5, eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, + latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", **kwargs, ): @@ -98,12 +99,18 @@ class StableDiffusionPipeline(DiffusionPipeline): # to avoid doing two forward passes text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - # get the intial random noise - latents = torch.randn( - (batch_size, self.unet.in_channels, height // 8, width // 8), - generator=generator, - device=self.device, - ) + # get the initial random noise unless the user supplied it + latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8) + if latents is None: + latents = torch.randn( + latents_shape, + generator=generator, + device=self.device, + ) + else: + if latents.shape != latents_shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + latents = latents.to(self.device) # set timesteps accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())