diff --git a/README.md b/README.md index 28e1be16..6fc01795 100644 --- a/README.md +++ b/README.md @@ -80,7 +80,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] ``` **Note**: If you don't want to use the token, you can also simply download the model weights @@ -101,7 +101,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] ``` If you are limited by GPU memory, you might want to consider using the model in `fp16`. @@ -117,7 +117,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] ``` Finally, if you wish to use a different scheduler, you can simply instantiate @@ -143,7 +143,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512)) prompt = "A fantasy landscape, trending on artstation" with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"] + images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images images[0].save("fantasy_landscape.png") ``` @@ -228,7 +228,7 @@ pipe = pipe.to(device) prompt = "a cat sitting on a bench" with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"] + images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images images[0].save("cat_on_bench.png") ``` @@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id) # run pipeline in inference (sample random noise and denoise) prompt = "A painting of a squirrel eating a burger" -images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"] +images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images # save images for idx, image in enumerate(images): @@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256" ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference # run pipeline in inference (sample random noise and denoise) -image = ddpm()["sample"] +image = ddpm().images # save image image[0].save("ddpm_generated_image.png") diff --git a/examples/textual_inversion/README.md b/examples/textual_inversion/README.md index 17dfe014..74ebcd51 100644 --- a/examples/textual_inversion/README.md +++ b/examples/textual_inversion/README.md @@ -76,7 +76,7 @@ pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda") prompt = "A backpack" with autocast("cuda"): - image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0] + image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] image.save("cat-backpack.png") ``` diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index fe19e0ea..dd8627e3 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -498,7 +498,7 @@ def main(): for step, batch in enumerate(train_dataloader): with
accelerator.accumulate(text_encoder): # Convert images to latent space - latents = vae.encode(batch["pixel_values"]).sample().detach() + latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() latents = latents * 0.18215 # Sample noise that we'll add to the latents @@ -515,7 +515,7 @@ def main(): encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"] + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 5edb2457..fe4e9b0d 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -139,7 +139,7 @@ def main(args): with accelerator.accumulate(model): # Predict the noise residual - noise_pred = model(noisy_images, timesteps)["sample"] + noise_pred = model(noisy_images, timesteps).sample loss = F.mse_loss(noise_pred, noise) accelerator.backward(loss) @@ -174,7 +174,7 @@ def main(args): generator = torch.manual_seed(0) # run pipeline in inference (sample random noise and denoise) - images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"] + images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images # denormalize the images and save to tensorboard images_processed = (images * 255).round().astype("uint8") diff --git a/scripts/generate_logits.py b/scripts/generate_logits.py index 61851212..47dc5485 100644 --- a/scripts/generate_logits.py +++ b/scripts/generate_logits.py @@ -119,7 +119,7 @@ for mod in models: noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) time_step = torch.tensor([10] * noise.shape[0]) with torch.no_grad(): - logits = model(noise, time_step)["sample"] + logits = model(noise, time_step).sample assert torch.allclose( logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3 diff --git a/src/diffusers/hub_utils.py b/src/diffusers/hub_utils.py index 0af2033a..c07329e3 100644 --- a/src/diffusers/hub_utils.py +++ b/src/diffusers/hub_utils.py @@ -19,9 +19,9 @@ import shutil from pathlib import Path from typing import Optional -from diffusers import DiffusionPipeline from huggingface_hub import HfFolder, Repository, whoami +from .pipeline_utils import DiffusionPipeline from .utils import is_modelcards_available, logging diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 86344563..46d5ee53 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -1,14 +1,27 @@ -from typing import Dict, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin +from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +@dataclass +class UNet2DOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states output. 
Output of last layer of model. + """ + + sample: torch.FloatTensor + + class UNet2DModel(ModelMixin, ConfigMixin): @register_to_config def __init__( @@ -118,8 +131,11 @@ class UNet2DModel(ModelMixin, ConfigMixin): self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1) def forward( - self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int] - ) -> Dict[str, torch.FloatTensor]: + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + return_dict: bool = True, + ) -> Union[UNet2DOutput, Tuple]: # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -181,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:])))) sample = sample / timesteps - output = {"sample": sample} + if not return_dict: + return (sample,) - return output + return UNet2DOutput(sample=sample) diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index a6f6eb79..e9aeb5eb 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -1,14 +1,27 @@ -from typing import Dict, Optional, Tuple, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union import torch import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin +from ..utils import BaseOutput from .embeddings import TimestepEmbedding, Timesteps from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor + + class UNet2DConditionModel(ModelMixin, ConfigMixin): @register_to_config def __init__( @@ -125,7 +138,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, - ) -> Dict[str, torch.FloatTensor]: + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 @@ -183,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): sample = self.conv_act(sample) sample = self.conv_out(sample) - output = {"sample": sample} + if not return_dict: + return (sample,) - return output + return UNet2DConditionOutput(sample=sample) diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py index d9fe8960..2897a3a0 100644 --- a/src/diffusers/models/vae.py +++ b/src/diffusers/models/vae.py @@ -1,4 +1,5 @@ -from typing import Optional, Tuple +from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch @@ -6,9 +7,50 @@ import torch.nn as nn from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin +from ..utils import BaseOutput from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block +@dataclass +class DecoderOutput(BaseOutput): + """ + Output of decoding method. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Decoded output sample of the model. Output of the last layer of the model. 
+ """ + + sample: torch.FloatTensor + + +@dataclass +class VQEncoderOutput(BaseOutput): + """ + Output of VQModel encoding method. + + Args: + latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Encoded output sample of the model. Output of the last layer of the model. + """ + + latents: torch.FloatTensor + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + class Encoder(nn.Module): def __init__( self, @@ -369,12 +411,18 @@ class VQModel(ModelMixin, ConfigMixin): act_fn=act_fn, ) - def encode(self, x): + def encode(self, x, return_dict: bool = True): h = self.encoder(x) h = self.quant_conv(h) - return h - def decode(self, h, force_not_quantize=False): + if not return_dict: + return (h,) + + return VQEncoderOutput(latents=h) + + def decode( + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: quant, emb_loss, info = self.quantize(h) @@ -382,13 +430,21 @@ class VQModel(ModelMixin, ConfigMixin): quant = h quant = self.post_quant_conv(quant) dec = self.decoder(quant) - return dec - def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: x = sample - h = self.encode(x) - dec = self.decode(h) - return dec + h = self.encode(x).latents + dec = self.decode(h).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) class AutoencoderKL(ModelMixin, ConfigMixin): @@ -431,23 +487,37 @@ class AutoencoderKL(ModelMixin, ConfigMixin): self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1) - def encode(self, x): + def encode(self, x, return_dict: bool = True): h = self.encoder(x) moments = self.quant_conv(h) posterior = DiagonalGaussianDistribution(moments) - return posterior - def decode(self, z): + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: z = self.post_quant_conv(z) dec = self.decoder(z) - return dec - def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor: + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: x = sample - posterior = self.encode(x) + posterior = self.encode(x).latent_dist if sample_posterior: z = posterior.sample() else: z = posterior.mode() - dec = self.decode(z) - return dec + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 214133bc..fc2bc7bc 100644 --- a/src/diffusers/pipeline_utils.py +++ 
b/src/diffusers/pipeline_utils.py @@ -17,16 +17,19 @@ import importlib import inspect import os -from typing import Optional, Union +from dataclasses import dataclass +from typing import List, Optional, Union +import numpy as np import torch +import PIL from huggingface_hub import snapshot_download from PIL import Image from tqdm.auto import tqdm from .configuration_utils import ConfigMixin -from .utils import DIFFUSERS_CACHE, logging +from .utils import DIFFUSERS_CACHE, BaseOutput, logging INDEX_FILE = "diffusion_pytorch_model.bin" @@ -54,6 +57,20 @@ for library in LOADABLE_CLASSES: ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) +@dataclass +class ImagePipelineOutput(BaseOutput): + """ + Output class for image pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + + class DiffusionPipeline(ConfigMixin): config_name = "model_index.json" diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 68fce06c..529f89fe 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -94,7 +94,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -130,7 +130,7 @@ init_image = init_image.resize((768, 512)) prompt = "A fantasy landscape, trending on artstation" with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"] + images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images images[0].save("fantasy_landscape.png") ``` @@ -174,7 +174,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained( prompt = "a cat sitting on a bench" with autocast("cuda"): - images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"] + images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images images[0].save("cat_on_bench.png") ``` diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 03a9c52b..3dbe32b4 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -15,10 +15,11 @@ import warnings +from typing import Tuple, Union import torch -from ...pipeline_utils import DiffusionPipeline +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DDIMPipeline(DiffusionPipeline): @@ -28,7 +29,16 @@ class DDIMPipeline(DiffusionPipeline): self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs): + def __call__( + self, + batch_size=1, + generator=None, + eta=0.0, + num_inference_steps=50, + output_type="pil", + return_dict: bool = True, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: if "torch_device" in kwargs: device = kwargs.pop("torch_device") @@ -56,15 +66,18 @@ class DDIMPipeline(DiffusionPipeline): for t in self.progress_bar(self.scheduler.timesteps): # 1. 
predict noise model_output - model_output = self.unet(image, t)["sample"] + model_output = self.unet(image, t).sample # 2. predict previous mean of image x_t-1 and add variance depending on eta # do x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, eta)["prev_sample"] + image = self.scheduler.step(model_output, t, image, eta).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image} + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 5d735a39..9111d9e1 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -15,10 +15,11 @@ import warnings +from typing import Tuple, Union import torch -from ...pipeline_utils import DiffusionPipeline +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput class DDPMPipeline(DiffusionPipeline): @@ -28,7 +29,9 @@ self.register_modules(unet=unet, scheduler=scheduler) @torch.no_grad() - def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs): + def __call__( + self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs + ) -> Union[ImagePipelineOutput, Tuple]: if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn( @@ -53,14 +56,17 @@ for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output - model_output = self.unet(image, t)["sample"] + model_output = self.unet(image, t).sample # 2. compute previous image: x_t -> x_t-1 - image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"] + image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image} + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index a348d9c0..efc847db 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging -from ...pipeline_utils import DiffusionPipeline +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput class LDMTextToImagePipeline(DiffusionPipeline): @@ -32,8 +32,9 @@ eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, **kwargs, - ): + ) -> Union[Tuple, ImagePipelineOutput]: # eta corresponds to η in paper and should be between [0, 1] if "torch_device" in kwargs: @@ -95,25 +96,28 @@ context = torch.cat([uncond_embeddings, text_embeddings]) # predict the noise residual - noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"] + noise_pred = self.unet(latents_input,
t, encoder_hidden_states=context).sample # perform guidance if guidance_scale != 1.0: noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vqvae.decode(latents) + image = self.vqvae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image} + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) ################################################################################ @@ -525,7 +529,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index b4d6db58..0ba4cfde 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -1,11 +1,11 @@ import inspect import warnings -from typing import Optional +from typing import Optional, Tuple, Union import torch from ...models import UNet2DModel, VQModel -from ...pipeline_utils import DiffusionPipeline +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import DDIMScheduler @@ -28,8 +28,9 @@ class LDMPipeline(DiffusionPipeline): eta: float = 0.0, num_inference_steps: int = 50, output_type: Optional[str] = "pil", + return_dict: bool = True, **kwargs, - ): + ) -> Union[Tuple, ImagePipelineOutput]: # eta corresponds to η in paper and should be between [0, 1] if "torch_device" in kwargs: @@ -61,16 +62,19 @@ class LDMPipeline(DiffusionPipeline): for t in self.progress_bar(self.scheduler.timesteps): # predict the noise residual - noise_prediction = self.unet(latents, t)["sample"] + noise_prediction = self.unet(latents, t).sample # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample # decode the image latents with the VAE - image = self.vqvae.decode(latents) + image = self.vqvae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image} + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index 632e3177..7b43c82a 100644 --- 
a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -15,12 +15,12 @@ import warnings -from typing import Optional +from typing import Optional, Tuple, Union import torch from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import PNDMScheduler @@ -40,8 +40,9 @@ class PNDMPipeline(DiffusionPipeline): num_inference_steps: int = 50, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, **kwargs, - ): + ) -> Union[ImagePipelineOutput, Tuple]: # For more information on the sampling method you can take a look at Algorithm 2 of # the official paper: https://arxiv.org/pdf/2202.09778.pdf @@ -66,13 +67,16 @@ class PNDMPipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) for t in self.progress_bar(self.scheduler.timesteps): - model_output = self.unet(image, t)["sample"] + model_output = self.unet(image, t).sample - image = self.scheduler.step(model_output, t, image)["prev_sample"] + image = self.scheduler.step(model_output, t, image).prev_sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image} + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 61008272..088f3141 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -1,12 +1,11 @@ #!/usr/bin/env python3 import warnings -from typing import Optional +from typing import Optional, Tuple, Union import torch -from diffusers import DiffusionPipeline - from ...models import UNet2DModel +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import ScoreSdeVeScheduler @@ -26,8 +25,9 @@ class ScoreSdeVePipeline(DiffusionPipeline): num_inference_steps: int = 2000, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, **kwargs, - ): + ) -> Union[ImagePipelineOutput, Tuple]: if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn( @@ -56,18 +56,21 @@ class ScoreSdeVePipeline(DiffusionPipeline): # correction step for _ in range(self.scheduler.correct_steps): - model_output = self.unet(sample, sigma_t)["sample"] - sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"] + model_output = self.unet(sample, sigma_t).sample + sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample # prediction step - model_output = model(sample, sigma_t)["sample"] + model_output = model(sample, sigma_t).sample output = self.scheduler.step_pred(model_output, t, sample, generator=generator) - sample, sample_mean = output["prev_sample"], output["prev_sample_mean"] + sample, sample_mean = output.prev_sample, output.prev_sample_mean sample = sample_mean.clamp(0, 1) sample = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": sample = self.numpy_to_pil(sample) - return {"sample": sample} + if not return_dict: + return (sample,) + + return ImagePipelineOutput(images=sample) diff --git a/src/diffusers/pipelines/stable_diffusion/README.md 
b/src/diffusers/pipelines/stable_diffusion/README.md index bb6aced0..64f17a3f 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -67,7 +67,7 @@ pipe = pipe.to("cuda") prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained( prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` @@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained( prompt = "a photo of an astronaut riding a horse on mars" with autocast("cuda"): - image = pipe(prompt)["sample"][0] + image = pipe(prompt).images[0] image.save("astronaut_rides_horse.png") ``` diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 1721caf0..8bfa394c 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -1,5 +1,31 @@ # flake8: noqa -from ...utils import is_transformers_available +from dataclasses import dataclass +from typing import List, Union + +import numpy as np + +import PIL +from PIL import Image + +from ...utils import BaseOutput, is_transformers_available + + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + nsfw_content_detected (`List[bool]`) + List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content. + """ + + images: Union[List[PIL.Image.Image], np.ndarray] + nsfw_content_detected: List[bool] if is_transformers_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index d4290da6..36ef4a19 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -9,6 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from .
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -47,6 +48,7 @@ class StableDiffusionPipeline(DiffusionPipeline): generator: Optional[torch.Generator] = None, latents: Optional[torch.FloatTensor] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, **kwargs, ): if "torch_device" in kwargs: @@ -141,7 +143,7 @@ class StableDiffusionPipeline(DiffusionPipeline): latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: @@ -150,13 +152,13 @@ class StableDiffusionPipeline(DiffusionPipeline): # compute the previous noisy sample x_t -> x_t-1 if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents) + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() @@ -168,4 +170,7 @@ class StableDiffusionPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image, "nsfw_content_detected": has_nsfw_concept} + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 6c64f45c..84b42761 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, ): if isinstance(prompt, str): @@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): init_image = preprocess(init_image) # encode the init image into latents and scale the latents - init_latents = self.vae.encode(init_image.to(self.device)).sample(generator=generator) + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) init_latents = 0.18215 * init_latents # expand init_latents for batch_size @@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): t = t.to(self.unet.dtype) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: @@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): # compute the previous noisy sample x_t -> x_t-1 if isinstance(self.scheduler, LMSDiscreteScheduler): - latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample else: - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents.to(self.vae.dtype)) + image = self.vae.decode(latents.to(self.vae.dtype)).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() @@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image, "nsfw_content_detected": has_nsfw_concept} + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 68278467..ca054adb 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, PNDMScheduler +from . 
import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker @@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): eta: Optional[float] = 0.0, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, ): if isinstance(prompt, str): @@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): init_image = preprocess_image(init_image).to(self.device) # encode the init image into latents and scale the latents - init_latents = self.vae.encode(init_image).sample(generator=generator) + init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist + init_latents = init_latent_dist.sample(generator=generator) + init_latents = 0.18215 * init_latents # Expand init_latents for batch_size @@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"] + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample # perform guidance if do_classifier_free_guidance: @@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"] + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # masking init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t) @@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): # scale and decode the image latents with vae latents = 1 / 0.18215 * latents - image = self.vae.decode(latents) + image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() @@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline): if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image, "nsfw_content_detected": has_nsfw_concept} + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index 007395a1..c35e04d1 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -1,11 +1,11 @@ #!/usr/bin/env python3 import warnings -from typing import Optional +from typing import Optional, Tuple, Union import torch from ...models import UNet2DModel -from ...pipeline_utils import DiffusionPipeline +from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ...schedulers import KarrasVeScheduler @@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline): num_inference_steps: int = 50, generator: Optional[torch.Generator] = None, output_type: Optional[str] = "pil", + return_dict: bool = True, **kwargs, - ): + ) -> Union[Tuple, ImagePipelineOutput]: if "torch_device" in kwargs: device = kwargs.pop("torch_device") warnings.warn( @@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline): # 3. 
Predict the noise residual given the noise magnitude `sigma_hat` # The model inputs and output are adjusted by following eq. (213) in [1]. - model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"] + model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample # 4. Evaluate dx/dt at sigma_hat # 5. Take Euler step from sigma to sigma_prev @@ -80,20 +81,23 @@ if sigma_prev != 0: # 6. Apply 2nd order correction # The model inputs and output are adjusted by following eq. (213) in [1]. - model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"] + model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample step_output = self.scheduler.step_correct( model_output, sigma_hat, sigma_prev, sample_hat, - step_output["prev_sample"], + step_output.prev_sample, step_output.derivative, ) - sample = step_output["prev_sample"] + sample = step_output.prev_sample sample = (sample / 2 + 0.5).clamp(0, 1) - sample = sample.cpu().permute(0, 2, 3, 1).numpy() + image = sample.cpu().permute(0, 2, 3, 1).numpy() if output_type == "pil": - sample = self.numpy_to_pil(sample) + image = self.numpy_to_pil(image) - return {"sample": sample} + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py index 1f0003f5..78c2f735 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -16,13 +16,13 @@ # and https://github.com/hojonathanho/diffusion import math -from typing import Optional, Union +from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): eta: float = 0.0, use_clipped_model_output: bool = False, generator=None, - ): + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + if self.num_inference_steps is None: raise ValueError( "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" @@ -174,7 +176,10 @@ prev_sample = prev_sample + variance - return {"prev_sample": prev_sample} + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index 789c9b2a..b0f4b081 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -15,13 +15,13 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -from typing import Optional, Union +from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): sample: Union[torch.FloatTensor, np.ndarray], predict_epsilon=True, generator=None, -
): + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + t = timestep if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]: @@ -177,7 +179,10 @@ pred_prev_sample = pred_prev_sample + variance - return {"prev_sample": pred_prev_sample} + if not return_dict: + return (pred_prev_sample,) + + return SchedulerOutput(prev_sample=pred_prev_sample) def add_noise( self, diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 320c682c..8c7d772f 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -13,15 +13,34 @@ # limitations under the License. -from typing import Union +from dataclasses import dataclass +from typing import Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput from .scheduling_utils import SchedulerMixin +@dataclass +class KarrasVeOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Derivative of predicted original image sample (x_0). + """ + + prev_sample: torch.FloatTensor + derivative: torch.FloatTensor + + class KarrasVeScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2].
Use Algorithm 2 and @@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): sigma_hat: float, sigma_prev: float, sample_hat: Union[torch.FloatTensor, np.ndarray], - ): + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + pred_original_sample = sample_hat + sigma_hat * model_output derivative = (sample_hat - pred_original_sample) / sigma_hat sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative - return {"prev_sample": sample_prev, "derivative": derivative} + if not return_dict: + return (sample_prev, derivative) + + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative) def step_correct( self, @@ -117,11 +141,17 @@ sample_hat: Union[torch.FloatTensor, np.ndarray], sample_prev: Union[torch.FloatTensor, np.ndarray], derivative: Union[torch.FloatTensor, np.ndarray], - ): + return_dict: bool = True, + ) -> Union[KarrasVeOutput, Tuple]: + pred_original_sample = sample_prev + sigma_prev * model_output derivative_corr = (sample_prev - pred_original_sample) / sigma_prev sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr) - return {"prev_sample": sample_prev, "derivative": derivative_corr} + + if not return_dict: + return (sample_prev, derivative_corr) + + return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative_corr) def add_noise(self, original_samples, noise, timesteps): raise NotImplementedError() diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index e6adcaac..8a6ce0d9 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union +from typing import Tuple, Union import numpy as np import torch @@ -20,7 +20,7 @@ import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): @@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): timestep: int, sample: Union[torch.FloatTensor, np.ndarray], order: int = 4, - ): + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: sigma = self.sigmas[timestep] # 1.
compute predicted original sample (x_0) from sigma-scaled predicted noise @@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives)) ) - return {"prev_sample": prev_sample} + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) def add_noise(self, original_samples, noise, timesteps): sigmas = self.match_shape(self.sigmas[timesteps], noise) diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index 57ed4fb7..a8778fed 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -15,13 +15,13 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim import math -from typing import Union +from typing import Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from .scheduling_utils import SchedulerMixin, SchedulerOutput def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): @@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - ): + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps: - return self.step_prk(model_output=model_output, timestep=timestep, sample=sample) + return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) else: - return self.step_plms(model_output=model_output, timestep=timestep, sample=sample) + return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict) def step_prk( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - ): + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the solution to the differential equation. @@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) self.counter += 1 - return {"prev_sample": prev_sample} + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) def step_plms( self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, sample: Union[torch.FloatTensor, np.ndarray], - ): + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: """ Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple times to approximate the solution. 
@@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) self.counter += 1 - return {"prev_sample": prev_sample} + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) def _get_prev_sample(self, sample, timestep, timestep_prev, model_output): # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index e3fec035..f6b0ba93 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -15,13 +15,32 @@ # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch import warnings -from typing import Optional, Union +from dataclasses import dataclass +from typing import Optional, Tuple, Union import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils import SchedulerMixin +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +@dataclass +class SdeVeOutput(BaseOutput): + """ + Output class for the ScoreSdeVeScheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps. + """ + + prev_sample: torch.FloatTensor + prev_sample_mean: torch.FloatTensor class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): @@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): timestep: int, sample: Union[torch.FloatTensor, np.ndarray], generator: Optional[torch.Generator] = None, + return_dict: bool = True, **kwargs, - ): + ) -> Union[SdeVeOutput, Tuple]: """ Predict the sample at the previous timestep by reversing the SDE. """ @@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): # TODO is the variable diffusion the correct scaling term for the noise? prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g - return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean} + if not return_dict: + return (prev_sample, prev_sample_mean) + + return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean) def step_correct( self, model_output: Union[torch.FloatTensor, np.ndarray], sample: Union[torch.FloatTensor, np.ndarray], generator: Optional[torch.Generator] = None, + return_dict: bool = True, **kwargs, - ): + ) -> Union[SchedulerOutput, Tuple]: """ Correct the predicted sample based on the output model_output of the network. This is often run repeatedly after making the prediction for the previous timestep. 
@@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): prev_sample_mean = sample + step_size[:, None, None, None] * model_output prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise - return {"prev_sample": prev_sample} + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) def __len__(self): return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index b0cd4bda..7d176e63 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,15 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Union import numpy as np import torch +from ..utils import BaseOutput + SCHEDULER_CONFIG_NAME = "scheduler_config.json" +@dataclass +class SchedulerOutput(BaseOutput): + """ + Base class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + class SchedulerMixin: config_name = SCHEDULER_CONFIG_NAME diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index d4e31b84..f9172e8d 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -33,6 +33,7 @@ from .import_utils import ( requires_backends, ) from .logging import get_logger +from .outputs import BaseOutput logger = get_logger(__name__) diff --git a/src/diffusers/utils/outputs.py b/src/diffusers/utils/outputs.py new file mode 100644 index 00000000..d8e695db --- /dev/null +++ b/src/diffusers/utils/outputs.py @@ -0,0 +1,139 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Generic utilities +""" + +import warnings +from collections import OrderedDict +from dataclasses import fields +from typing import Any, Tuple + +import numpy as np + +from .import_utils import is_torch_available + + +def is_tensor(x): + """ + Tests if `x` is a `torch.Tensor` or `np.ndarray`. + """ + if is_torch_available(): + import torch + + if isinstance(x, torch.Tensor): + return True + + return isinstance(x, np.ndarray) + + +class BaseOutput(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple + before. 
+ + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not len(class_fields): + raise ValueError(f"{self.__class__.__name__} has no fields.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for element in iterator: + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = {k: v for (k, v) in self.items()} + if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample": + warnings.warn( + "The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or" + " `'images'` instead.", + DeprecationWarning, + ) + return inner_dict["images"] + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. 
+ """ + return tuple(self[k] for k in self.keys()) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index e1923805..8c7c6312 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -15,6 +15,7 @@ import inspect import tempfile +from typing import Dict, List, Tuple import numpy as np import torch @@ -39,12 +40,12 @@ class ModelTesterMixin: with torch.no_grad(): image = model(**inputs_dict) if isinstance(image, dict): - image = image["sample"] + image = image.sample new_image = new_model(**inputs_dict) if isinstance(new_image, dict): - new_image = new_image["sample"] + new_image = new_image.sample max_diff = (image - new_image).abs().sum().item() self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes") @@ -57,11 +58,11 @@ class ModelTesterMixin: with torch.no_grad(): first = model(**inputs_dict) if isinstance(first, dict): - first = first["sample"] + first = first.sample second = model(**inputs_dict) if isinstance(second, dict): - second = second["sample"] + second = second.sample out_1 = first.cpu().numpy() out_2 = second.cpu().numpy() @@ -80,7 +81,7 @@ class ModelTesterMixin: output = model(**inputs_dict) if isinstance(output, dict): - output = output["sample"] + output = output.sample self.assertIsNotNone(output) expected_shape = inputs_dict["sample"].shape @@ -122,12 +123,12 @@ class ModelTesterMixin: output_1 = model(**inputs_dict) if isinstance(output_1, dict): - output_1 = output_1["sample"] + output_1 = output_1.sample output_2 = new_model(**inputs_dict) if isinstance(output_2, dict): - output_2 = output_2["sample"] + output_2 = output_2.sample self.assertEqual(output_1.shape, output_2.shape) @@ -140,7 +141,7 @@ class ModelTesterMixin: output = model(**inputs_dict) if isinstance(output, dict): - output = output["sample"] + output = output.sample noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) @@ -157,9 +158,47 @@ class ModelTesterMixin: output = model(**inputs_dict) if isinstance(output, dict): - output = output["sample"] + output = output.sample noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) loss.backward() ema_model.step(model) + + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}." 
+ ), + ) + + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + outputs_dict = model(**inputs_dict) + outputs_tuple = model(**inputs_dict, return_dict=False) + + recursive_check(outputs_tuple, outputs_dict) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 47f562dd..c574a009 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase): # time_step = torch.tensor([10]) # # with torch.no_grad(): -# output = model(noise, time_step)["sample"] +# output = model(noise, time_step).sample # # output_slice = output[0, -1, -3:, -3:].flatten() # fmt: off @@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): self.assertEqual(len(loading_info["missing_keys"]), 0) model.to(torch_device) - image = model(**self.dummy_input)["sample"] + image = model(**self.dummy_input).sample assert image is not None, "Make sure output is not None" @@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): time_step = torch.tensor([10] * noise.shape[0]).to(torch_device) with torch.no_grad(): - output = model(noise, time_step)["sample"] + output = model(noise, time_step).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() # fmt: off @@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) with torch.no_grad(): - output = model(noise, time_step)["sample"] + output = model(noise, time_step).sample output_slice = output[0, -3:, -3:, -1].flatten().cpu() # fmt: off @@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase): time_step = torch.tensor(batch_size * [1e-4]).to(torch_device) with torch.no_grad(): - output = model(noise, time_step)["sample"] + output = model(noise, time_step).sample output_slice = output[0, -3:, -3:, -1].flatten().cpu() # fmt: off diff --git a/tests/test_models_vae.py b/tests/test_models_vae.py index 7df6b42b..adf9767d 100644 --- a/tests/test_models_vae.py +++ b/tests/test_models_vae.py @@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) with torch.no_grad(): - output = model(image, sample_posterior=True) + output = model(image, sample_posterior=True).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() # fmt: off diff --git a/tests/test_models_vq.py b/tests/test_models_vq.py index 4040bad9..c0accecc 100644 --- a/tests/test_models_vq.py +++ b/tests/test_models_vq.py @@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase): image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size) image = image.to(torch_device) with torch.no_grad(): - output = model(image) + output = model(image).sample output_slice = output[0, -1, -3:, -3:].flatten().cpu() # fmt: off diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 257517a1..fbd0faf0 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -67,12 +67,12 @@ def test_progress_bar(capsys): scheduler = DDPMScheduler(num_train_timesteps=10) ddpm = DDPMPipeline(model, scheduler).to(torch_device) - ddpm(output_type="numpy")["sample"] + ddpm(output_type="numpy").images captured = capsys.readouterr() assert "10/10" in captured.err, "Progress 
bar has to be displayed" ddpm.set_progress_bar_config(disable=True) - ddpm(output_type="numpy")["sample"] + ddpm(output_type="numpy").images captured = capsys.readouterr() assert captured.err == "", "Progress bar should be disabled" @@ -196,15 +196,20 @@ class PipelineFastTests(unittest.TestCase): ddpm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"] + image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) expected_slice = np.array( [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04] ) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_pndm_cifar10(self): unet = self.dummy_uncond_unet @@ -213,14 +218,20 @@ class PipelineFastTests(unittest.TestCase): pndm = PNDMPipeline(unet=unet, scheduler=scheduler) pndm.to(torch_device) pndm.set_progress_bar_config(disable=None) + generator = torch.manual_seed(0) - image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"] + image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_ldm_text2img(self): unet = self.dummy_cond_unet @@ -239,11 +250,23 @@ class PipelineFastTests(unittest.TestCase): "sample" ] + generator = torch.manual_seed(0) + image_from_tuple = ldm( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="numpy", + return_dict=False, + )[0] + image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_ddim(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -274,16 +297,28 @@ class PipelineFastTests(unittest.TestCase): sd_pipe.set_progress_bar_config(disable=None) prompt = "A painting of a squirrel eating a burger" + generator = torch.Generator(device=device).manual_seed(0) output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") + image = output.images - image = output["sample"] + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] image_slice = image[0, -3:, -3:, -1] + 
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 128, 128, 3) expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_pndm(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -310,13 +345,25 @@ class PipelineFastTests(unittest.TestCase): generator = torch.Generator(device=device).manual_seed(0) output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - image = output["sample"] + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 128, 128, 3) expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -343,13 +390,25 @@ class PipelineFastTests(unittest.TestCase): generator = torch.Generator(device=device).manual_seed(0) output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np") - image = output["sample"] + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + return_dict=False, + )[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 128, 128, 3) expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_score_sde_ve_pipeline(self): unet = self.dummy_uncond_unet @@ -360,14 +419,19 @@ class PipelineFastTests(unittest.TestCase): sde_ve.set_progress_bar_config(disable=None) torch.manual_seed(0) - image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"] + image = sde_ve(num_inference_steps=2, output_type="numpy").images + + torch.manual_seed(0) + image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_ldm_uncond(self): unet = self.dummy_uncond_unet @@ -379,13 +443,18 @@ class PipelineFastTests(unittest.TestCase): ldm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"] + image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images + + generator = 
torch.manual_seed(0) + image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 64, 64, 3) expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_karras_ve_pipeline(self): unet = self.dummy_uncond_unet @@ -396,12 +465,18 @@ class PipelineFastTests(unittest.TestCase): pipe.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"] + image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images + + generator = torch.manual_seed(0) + image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] + assert image.shape == (1, 32, 32, 3) expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_img2img(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -437,13 +512,26 @@ class PipelineFastTests(unittest.TestCase): init_image=init_image, ) - image = output["sample"] + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + return_dict=False, + )[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_img2img_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -479,14 +567,27 @@ class PipelineFastTests(unittest.TestCase): output_type="np", init_image=init_image, ) + image = output.images - image = output["sample"] + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + return_dict=False, + ) + image_from_tuple = output[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_inpaint(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator @@ -525,13 +626,27 @@ class PipelineFastTests(unittest.TestCase): mask_image=mask_image, ) - image = output["sample"] + image = output.images + + generator = torch.Generator(device=device).manual_seed(0) + 
image_from_tuple = sd_pipe( + [prompt], + generator=generator, + guidance_scale=6.0, + num_inference_steps=2, + output_type="np", + init_image=init_image, + mask_image=mask_image, + return_dict=False, + )[0] image_slice = image[0, -3:, -3:, -1] + image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1] assert image.shape == (1, 32, 32, 3) expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153]) assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 + assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 class PipelineTesterMixin(unittest.TestCase): @@ -565,9 +680,9 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy")["sample"] + image = ddpm(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) - new_image = new_ddpm(generator=generator, output_type="numpy")["sample"] + new_image = new_ddpm(generator=generator, output_type="numpy").images assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" @@ -586,9 +701,9 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy")["sample"] + image = ddpm(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"] + new_image = ddpm_from_hub(generator=generator, output_type="numpy").images assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" @@ -610,9 +725,9 @@ class PipelineTesterMixin(unittest.TestCase): generator = torch.manual_seed(0) - image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"] + image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images generator = generator.manual_seed(0) - new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"] + new_image = ddpm_from_hub(generator=generator, output_type="numpy").images assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass" @@ -625,17 +740,17 @@ class PipelineTesterMixin(unittest.TestCase): pipe.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - images = pipe(generator=generator, output_type="numpy")["sample"] + images = pipe(generator=generator, output_type="numpy").images assert images.shape == (1, 32, 32, 3) assert isinstance(images, np.ndarray) - images = pipe(generator=generator, output_type="pil")["sample"] + images = pipe(generator=generator, output_type="pil").images assert isinstance(images, list) assert len(images) == 1 assert isinstance(images[0], PIL.Image.Image) # use PIL by default - images = pipe(generator=generator)["sample"] + images = pipe(generator=generator).images assert isinstance(images, list) assert isinstance(images[0], PIL.Image.Image) @@ -652,7 +767,7 @@ class PipelineTesterMixin(unittest.TestCase): ddpm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy")["sample"] + image = ddpm(generator=generator, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -672,7 +787,7 @@ class PipelineTesterMixin(unittest.TestCase): ddpm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ddpm(generator=generator, output_type="numpy")["sample"] + image = ddpm(generator=generator, 
output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -692,7 +807,7 @@ class PipelineTesterMixin(unittest.TestCase): ddim.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"] + image = ddim(generator=generator, eta=0.0, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -711,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase): pndm.to(torch_device) pndm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = pndm(generator=generator, output_type="numpy")["sample"] + image = pndm(generator=generator, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -745,7 +860,7 @@ class PipelineTesterMixin(unittest.TestCase): prompt = "A painting of a squirrel eating a burger" generator = torch.manual_seed(0) - image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"] + image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -768,7 +883,7 @@ class PipelineTesterMixin(unittest.TestCase): [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np" ) - image = output["sample"] + image = output.images image_slice = image[0, -3:, -3:, -1] @@ -797,7 +912,7 @@ class PipelineTesterMixin(unittest.TestCase): with torch.autocast("cuda"): output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy") - image = output["sample"] + image = output.images image_slice = image[0, -3:, -3:, -1] @@ -817,7 +932,7 @@ class PipelineTesterMixin(unittest.TestCase): sde_ve.set_progress_bar_config(disable=None) torch.manual_seed(0) - image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"] + image = sde_ve(num_inference_steps=300, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -833,7 +948,7 @@ class PipelineTesterMixin(unittest.TestCase): ldm.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"] + image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] @@ -857,10 +972,10 @@ class PipelineTesterMixin(unittest.TestCase): ddim.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"] + ddpm_image = ddpm(generator=generator, output_type="numpy").images generator = torch.manual_seed(0) - ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"] + ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images # the values aren't exactly equal, but the images look the same visually assert np.abs(ddpm_image - ddim_image).max() < 1e-1 @@ -882,7 +997,7 @@ class PipelineTesterMixin(unittest.TestCase): ddim.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"] + ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images generator = torch.manual_seed(0) ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[ @@ -903,7 +1018,7 @@ class PipelineTesterMixin(unittest.TestCase): pipe.set_progress_bar_config(disable=None) generator = torch.manual_seed(0) - image 
= pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"] + image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images image_slice = image[0, -3:, -3:, -1] assert image.shape == (1, 256, 256, 3) @@ -974,9 +1089,8 @@ class PipelineTesterMixin(unittest.TestCase): prompt = "A fantasy landscape, trending on artstation" generator = torch.Generator(device=torch_device).manual_seed(0) - image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[ - "sample" - ][0] + output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator) + image = output.images[0] expected_array = np.array(output_image) sampled_array = np.array(image) @@ -1008,7 +1122,7 @@ class PipelineTesterMixin(unittest.TestCase): strength=0.75, guidance_scale=7.5, generator=generator, - )["sample"][0] + ).images[0] expected_array = np.array(output_image) sampled_array = np.array(image) diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 0ce6715f..3c2e786f 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -14,6 +14,7 @@ # limitations under the License. import tempfile import unittest +from typing import Dict, List, Tuple import numpy as np import torch @@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase): kwargs["num_inference_steps"] = num_inference_steps torch.manual_seed(0) - output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample torch.manual_seed(0) - new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"] + new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase): kwargs["num_inference_steps"] = num_inference_steps torch.manual_seed(0) - output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] + output = scheduler.step(residual, 1, sample, **kwargs).prev_sample torch.manual_seed(0) - new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] + new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"] - output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] + output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step(residual, 1, sample, 
**kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) @@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"] - output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] + output = scheduler.step(residual, 1, sample, **kwargs).prev_sample + output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" + def test_scheduler_outputs_equivalence(self): + def set_nan_tensor_to_zero(t): + t[t != t] = 0 + return t + + def recursive_check(tuple_object, dict_object): + if isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=( + "Tuple and dict output are not equal. Difference:" + f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:" + f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}. Dict has" + f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object).any()}." + ), + ) + + kwargs = dict(self.forward_default_kwargs) + num_inference_steps = kwargs.pop("num_inference_steps", None) + + for scheduler_class in self.scheduler_classes: + scheduler_config = self.get_scheduler_config() + scheduler = scheduler_class(**scheduler_config) + + sample = self.dummy_sample + residual = 0.1 * sample + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_dict = scheduler.step(residual, 0, sample, **kwargs) + + if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"): + scheduler.set_timesteps(num_inference_steps) + elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): + kwargs["num_inference_steps"] = num_inference_steps + + outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs) + + recursive_check(outputs_tuple, outputs_dict) + class DDPMSchedulerTest(SchedulerCommonTest): scheduler_classes = (DDPMScheduler,) @@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest): residual = model(sample, t) # 2.
predict previous mean of sample x_t-1 - pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"] + pred_prev_sample = scheduler.step(residual, t, sample).prev_sample # if t > 0: # noise = self.dummy_sample_deter @@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest): for t in scheduler.timesteps: residual = model(sample, t) - sample = scheduler.step(residual, t, sample, eta)["prev_sample"] + sample = scheduler.step(residual, t, sample, eta).prev_sample result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest): # copy over dummy past residuals new_scheduler.ets = dummy_past_residuals[:] - output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest): # copy over dummy past residual (must be after setting timesteps) new_scheduler.ets = dummy_past_residuals[:] - output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" @@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler.ets = dummy_past_residuals[:] scheduler_pt.ets = dummy_past_residuals_pt[:] - output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] - output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] + output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample + output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" - output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"] - output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"] + output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample + output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, 
**kwargs).prev_sample assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical" @@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest): dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05] scheduler.ets = dummy_past_residuals[:] - output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"] - output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"] + output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) - output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"] - output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"] + output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) @@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest): scheduler_config = self.get_scheduler_config() scheduler = scheduler_class(**scheduler_config) - scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"] + scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample def test_full_loop_no_noise(self): scheduler_class = self.scheduler_classes[0] @@ -587,11 +641,11 @@ for i, t in enumerate(scheduler.prk_timesteps): residual = model(sample, t) - sample = scheduler.step_prk(residual, i, sample)["prev_sample"] + sample = scheduler.step_prk(residual, i, sample).prev_sample for i, t in enumerate(scheduler.plms_timesteps): residual = model(sample, t) - sample = scheduler.step_plms(residual, i, sample)["prev_sample"] + sample = scheduler.step_plms(residual, i, sample).prev_sample result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] + output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample + new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] + output = scheduler.step_correct(residual, sample, **kwargs).prev_sample + new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical" @@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): scheduler.save_config(tmpdirname) new_scheduler = scheduler_class.from_config(tmpdirname) - output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"] - new_output = 
new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical" - output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] - new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"] + output = scheduler.step_correct(residual, sample, **kwargs).prev_sample + new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical" @@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): for _ in range(scheduler.correct_steps): with torch.no_grad(): model_output = model(sample, sigma_t) - sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"] + sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample with torch.no_grad(): model_output = model(sample, sigma_t) output = scheduler.step_pred(model_output, t, sample, **kwargs) - sample, _ = output["prev_sample"], output["prev_sample_mean"] + sample, _ = output.prev_sample, output.prev_sample_mean result_sum = torch.sum(torch.abs(sample)) result_mean = torch.mean(torch.abs(sample)) @@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase): elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"): kwargs["num_inference_steps"] = num_inference_steps - output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"] - output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"] + output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample + output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample self.assertEqual(output_0.shape, sample.shape) self.assertEqual(output_0.shape, output_1.shape) diff --git a/tests/test_training.py b/tests/test_training.py index a9d330ef..27caf033 100644 --- a/tests/test_training.py +++ b/tests/test_training.py @@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase): for i in range(4): optimizer.zero_grad() ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i]) - ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"] + ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i]).sample loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i]) loss.backward() optimizer.step() @@ -78,7 +78,7 @@ for i in range(4): optimizer.zero_grad() ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i]) - ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"] + ddim_noise_pred = model(ddim_noisy_images, timesteps[i]).sample loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i]) loss.backward() optimizer.step()
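The equivalence tests above pin down the semantics every `BaseOutput` subclass is expected to satisfy: attribute, string-key, and integer access all resolve to the same values, with `None` fields skipped. A small demonstration using a hypothetical `ToyOutput` subclass (not part of the library):

```python
from dataclasses import dataclass
from typing import Optional

import torch

from diffusers.utils import BaseOutput


@dataclass
class ToyOutput(BaseOutput):
    """Hypothetical output class, for illustration only."""

    prev_sample: torch.FloatTensor
    prev_sample_mean: Optional[torch.FloatTensor] = None


out = ToyOutput(prev_sample=torch.zeros(2))

# attribute, string-key, and integer access all resolve to the same tensor
assert out.prev_sample is out["prev_sample"] is out[0]

# fields left as `None` are dropped from the dict and tuple views
assert list(out.keys()) == ["prev_sample"]
assert len(out.to_tuple()) == 1 and out.to_tuple()[0] is out.prev_sample

# writing through the attribute keeps the dict view in sync
out.prev_sample = torch.ones(2)
assert out["prev_sample"] is out.prev_sample
```

Note that iterating or unpacking the object itself yields its keys, as with any dict, which is why the class docstring points to `to_tuple()` for unpacking.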
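At the pipeline level, which is what the `test_pipelines.py` changes exercise, the same contract reads as follows. A sketch under the assumption of a standard DDPM checkpoint; the checkpoint name is only an example:

```python
import torch

from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

generator = torch.manual_seed(0)
image = pipe(generator=generator).images[0]  # named attribute on the output object
image.save("ddpm_generated_image.png")

# return_dict=False yields a plain tuple whose first element is the image list
generator = torch.manual_seed(0)
image_from_tuple = pipe(generator=generator, return_dict=False)[0][0]

# the old dict-style key still resolves through the deprecation shim in
# utils/outputs.py, but warns and is slated for removal in v0.4.0:
#   pipe(generator=generator)["sample"]
```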