[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow to return tuples (#334)

* add outputs for models

* add for pipelines

* finish schedulers

* better naming

* adapt tests as well

* replace dict access with . access

* make schedulers work

* finish

* correct readme

* make backwards compatible

* up

* small fix

* finish

* more fixes

* more fixes

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Update src/diffusers/models/vae.py

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Adapt model outputs

* Apply more suggestions

* finish examples

* correct

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Patrick von Platen 2022-09-05 14:49:26 +02:00 committed by GitHub
parent daddd98b88
commit cc59b05635
39 changed files with 893 additions and 247 deletions
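In short: model, pipeline, and scheduler calls now return `BaseOutput` dataclasses instead of plain dicts, and every call site gains a `return_dict=False` flag that yields a bare tuple instead. A hedged before/after sketch, where `pipe`, `unet`, and `scheduler` stand in for the objects changed below:

```python
# before this commit: dict-style access
image = pipe(prompt)["sample"][0]
noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
latents = scheduler.step(noise_pred, t, latents)["prev_sample"]

# after: attribute access on typed dataclass outputs
image = pipe(prompt).images[0]
noise_pred = unet(latents, t, encoder_hidden_states=context).sample
latents = scheduler.step(noise_pred, t, latents).prev_sample

# or opt out of the dataclass and index into a plain tuple
image = pipe(prompt, return_dict=False)[0]
```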

View File

@ -80,7 +80,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
```
**Note**: If you don't want to use the token, you can also simply download the model weights
@ -101,7 +101,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
```
If you are limited by GPU memory, you might want to consider using the model in `fp16`.
@ -117,7 +117,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
```
Finally, if you wish to use a different scheduler, you can simply instantiate
@ -143,7 +143,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@ -228,7 +228,7 @@ pipe = pipe.to(device)
prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```
@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)
# run pipeline in inference (sample random noise and denoise)
prompt = "A painting of a squirrel eating a burger"
images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images
# save images
for idx, image in enumerate(images):
@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
# run pipeline in inference (sample random noise and denoise)
image = ddpm()["sample"]
image = ddpm().images
# save image
image[0].save("ddpm_generated_image.png")
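For backwards compatibility, the output classes subclass `BaseOutput` (itself an `OrderedDict`), so dict-style indexing keeps working; for the image pipelines the legacy `"sample"` key is remapped to `.images` with a `DeprecationWarning` (see `outputs.py` further down). A hedged sketch:

```python
out = ddpm(output_type="numpy")

image = out.images     # new attribute access
image = out["images"]  # dict-style access still works
image = out["sample"]  # legacy key: warns, then returns .images
image = out[0]         # integer indexing falls back to the tuple view
```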

View File

@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
prompt = "A <cat-toy> backpack"
with autocast("cuda"):
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
```

View File

@ -498,7 +498,7 @@ def main():
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"]).sample().detach()
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
@ -515,7 +515,7 @@ def main():
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)

View File

@ -139,7 +139,7 @@ def main(args):
with accelerator.accumulate(model):
# Predict the noise residual
noise_pred = model(noisy_images, timesteps)["sample"]
noise_pred = model(noisy_images, timesteps).sample
loss = F.mse_loss(noise_pred, noise)
accelerator.backward(loss)
@ -174,7 +174,7 @@ def main(args):
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")

View File

@ -119,7 +119,7 @@ for mod in models:
noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
time_step = torch.tensor([10] * noise.shape[0])
with torch.no_grad():
logits = model(noise, time_step)["sample"]
logits = model(noise, time_step).sample
assert torch.allclose(
logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3

View File

@ -19,9 +19,9 @@ import shutil
from pathlib import Path
from typing import Optional
from diffusers import DiffusionPipeline
from huggingface_hub import HfFolder, Repository, whoami
from .pipeline_utils import DiffusionPipeline
from .utils import is_modelcards_available, logging

View File

@ -1,14 +1,27 @@
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class UNet2DOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Hidden states output. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet2DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
@ -118,8 +131,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def forward(
self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
) -> Dict[str, torch.FloatTensor]:
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
return_dict: bool = True,
) -> Union[UNet2DOutput, Tuple]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
@ -181,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
sample = sample / timesteps
output = {"sample": sample}
if not return_dict:
return (sample,)
return output
return UNet2DOutput(sample=sample)
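A minimal sketch of the new forward contract; the config values below are illustrative, not taken from this diff:

```python
import torch
from diffusers import UNet2DModel

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3).eval()
noise = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])

with torch.no_grad():
    out = model(noise, timestep)                                 # UNet2DOutput
    (sample_tuple,) = model(noise, timestep, return_dict=False)  # plain tuple

# both paths carry the same tensor
assert torch.equal(out.sample, sample_tuple)
```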

View File

@ -1,14 +1,27 @@
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
@dataclass
class UNet2DConditionOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
"""
sample: torch.FloatTensor
class UNet2DConditionModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
@ -125,7 +138,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
) -> Dict[str, torch.FloatTensor]:
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
@ -183,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
sample = self.conv_act(sample)
sample = self.conv_out(sample)
output = {"sample": sample}
if not return_dict:
return (sample,)
return output
return UNet2DConditionOutput(sample=sample)
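The conditioned UNet follows the same pattern; a hedged call-site sketch with the surrounding variables assumed from a pipeline loop:

```python
# dataclass output
noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

# tuple output
(noise_pred,) = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, return_dict=False)
```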

View File

@ -1,4 +1,5 @@
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
@ -6,9 +7,50 @@ import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
@dataclass
class DecoderOutput(BaseOutput):
"""
Output of decoding method.
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Decoded output sample of the model. Output of the last layer of the model.
"""
sample: torch.FloatTensor
@dataclass
class VQEncoderOutput(BaseOutput):
"""
Output of VQModel encoding method.
Args:
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Encoded output sample of the model. Output of the last layer of the model.
"""
latents: torch.FloatTensor
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class Encoder(nn.Module):
def __init__(
self,
@ -369,12 +411,18 @@ class VQModel(ModelMixin, ConfigMixin):
act_fn=act_fn,
)
def encode(self, x):
def encode(self, x, return_dict: bool = True):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
if not return_dict:
return (h,)
return VQEncoderOutput(latents=h)
def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
@ -382,13 +430,21 @@ class VQModel(ModelMixin, ConfigMixin):
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
x = sample
h = self.encode(x)
dec = self.decode(h)
return dec
h = self.encode(x).latents
dec = self.decode(h).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
class AutoencoderKL(ModelMixin, ConfigMixin):
@ -431,23 +487,37 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
def encode(self, x):
def encode(self, x, return_dict: bool = True):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
x = sample
posterior = self.encode(x)
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
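`encode` and `decode` now return distinct output types, so callers pick the field explicitly. A hedged round-trip sketch with illustrative config values:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(in_channels=3, out_channels=3, latent_channels=4).eval()
image = torch.randn(1, 3, 64, 64)

with torch.no_grad():
    posterior = vae.encode(image).latent_dist    # AutoencoderKLOutput
    latents = posterior.sample()                 # draw from the diagonal Gaussian
    reconstruction = vae.decode(latents).sample  # DecoderOutput
```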

View File

@ -17,16 +17,19 @@
import importlib
import inspect
import os
from typing import Optional, Union
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import torch
import PIL
from huggingface_hub import snapshot_download
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .utils import DIFFUSERS_CACHE, logging
from .utils import DIFFUSERS_CACHE, BaseOutput, logging
INDEX_FILE = "diffusion_pytorch_model.bin"
@ -54,6 +57,20 @@ for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@dataclass
class ImagePipelineOutput(BaseOutput):
"""
Output class for image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
class DiffusionPipeline(ConfigMixin):
config_name = "model_index.json"

View File

@ -94,7 +94,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).images[0]
image.save("astronaut_rides_horse.png")
```
@ -130,7 +130,7 @@ init_image = init_image.resize((768, 512))
prompt = "A fantasy landscape, trending on artstation"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
images[0].save("fantasy_landscape.png")
```
@ -174,7 +174,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
prompt = "a cat sitting on a bench"
with autocast("cuda"):
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
images[0].save("cat_on_bench.png")
```

View File

@ -15,10 +15,11 @@
import warnings
from typing import Tuple, Union
import torch
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DDIMPipeline(DiffusionPipeline):
@ -28,7 +29,16 @@ class DDIMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
def __call__(
self,
batch_size=1,
generator=None,
eta=0.0,
num_inference_steps=50,
output_type="pil",
return_dict: bool = True,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
@ -56,15 +66,18 @@ class DDIMPipeline(DiffusionPipeline):
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t).sample
# 2. predict previous mean of image x_t-1 and add variance depending on eta
# do x_t -> x_t-1
image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
image = self.scheduler.step(model_output, t, image, eta).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@ -15,10 +15,11 @@
import warnings
from typing import Tuple, Union
import torch
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DDPMPipeline(DiffusionPipeline):
@ -28,7 +29,9 @@ class DDPMPipeline(DiffusionPipeline):
self.register_modules(unet=unet, scheduler=scheduler)
@torch.no_grad()
def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
def __call__(
self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
) -> Union[ImagePipelineOutput, Tuple]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
@ -53,14 +56,17 @@ class DDPMPipeline(DiffusionPipeline):
for t in self.progress_bar(self.scheduler.timesteps):
# 1. predict noise model_output
model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t).sample
# 2. compute previous image: x_t -> t_t-1
image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
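Hand-rolled sampling loops shrink to the same two attribute lookups; a hedged sketch of the loop body, with `unet`, `scheduler`, and `image` assumed already set up:

```python
for t in scheduler.timesteps:
    model_output = unet(image, t).sample                        # was ...["sample"]
    image = scheduler.step(model_output, t, image).prev_sample  # was ...["prev_sample"]
```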

View File

@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class LDMTextToImagePipeline(DiffusionPipeline):
@ -32,8 +32,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
# eta corresponds to η in paper and should be between [0, 1]
if "torch_device" in kwargs:
@ -95,25 +96,28 @@ class LDMTextToImagePipeline(DiffusionPipeline):
context = torch.cat([uncond_embeddings, text_embeddings])
# predict the noise residual
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
# perform guidance
if guidance_scale != 1.0:
noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vqvae.decode(latents)
image = self.vqvae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
################################################################################
@ -525,7 +529,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (

View File

@ -1,11 +1,11 @@
import inspect
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel, VQModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler
@ -28,8 +28,9 @@ class LDMPipeline(DiffusionPipeline):
eta: float = 0.0,
num_inference_steps: int = 50,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
# eta corresponds to η in paper and should be between [0, 1]
if "torch_device" in kwargs:
@ -61,16 +62,19 @@ class LDMPipeline(DiffusionPipeline):
for t in self.progress_bar(self.scheduler.timesteps):
# predict the noise residual
noise_prediction = self.unet(latents, t)["sample"]
noise_prediction = self.unet(latents, t).sample
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
# decode the image latents with the VAE
image = self.vqvae.decode(latents)
image = self.vqvae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@ -15,12 +15,12 @@
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import PNDMScheduler
@ -40,8 +40,9 @@ class PNDMPipeline(DiffusionPipeline):
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[ImagePipelineOutput, Tuple]:
# For more information on the sampling method you can take a look at Algorithm 2 of
# the official paper: https://arxiv.org/pdf/2202.09778.pdf
@ -66,13 +67,16 @@ class PNDMPipeline(DiffusionPipeline):
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
model_output = self.unet(image, t)["sample"]
model_output = self.unet(image, t).sample
image = self.scheduler.step(model_output, t, image)["prev_sample"]
image = self.scheduler.step(model_output, t, image).prev_sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@ -1,12 +1,11 @@
#!/usr/bin/env python3
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from diffusers import DiffusionPipeline
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import ScoreSdeVeScheduler
@ -26,8 +25,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
num_inference_steps: int = 2000,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[ImagePipelineOutput, Tuple]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
@ -56,18 +56,21 @@ class ScoreSdeVePipeline(DiffusionPipeline):
# correction step
for _ in range(self.scheduler.correct_steps):
model_output = self.unet(sample, sigma_t)["sample"]
sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
model_output = self.unet(sample, sigma_t).sample
sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
# prediction step
model_output = model(sample, sigma_t)["sample"]
model_output = model(sample, sigma_t).sample
output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
sample, sample_mean = output.prev_sample, output.prev_sample_mean
sample = sample_mean.clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
return {"sample": sample}
if not return_dict:
return (sample,)
return ImagePipelineOutput(images=sample)

View File

@ -67,7 +67,7 @@ pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```
@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
prompt = "a photo of an astronaut riding a horse on mars"
with autocast("cuda"):
image = pipe(prompt)["sample"][0]
image = pipe(prompt).sample[0]
image.save("astronaut_rides_horse.png")
```

View File

@ -1,5 +1,31 @@
# flake8: noqa
from ...utils import is_transformers_available
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import BaseOutput, is_transformers_available
@dataclass
class StableDiffusionPipelineOutput(BaseOutput):
"""
Output class for Stable Diffusion pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
nsfw_content_detected (`List[bool]`)
List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
nsfw_content_detected: List[bool]
if is_transformers_available():

View File

@ -9,6 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@ -47,6 +48,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
if "torch_device" in kwargs:
@ -141,7 +143,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
@ -150,13 +152,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
@ -168,4 +170,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
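Downstream code now reads both Stable Diffusion fields by name; a hedged sketch:

```python
out = pipe(prompt)
images = out.images                # List[PIL.Image.Image] (or np.ndarray)
flags = out.nsfw_content_detected  # List[bool], True where the safety checker fired

# the tuple form preserves the field order (images, nsfw_content_detected)
images, flags = pipe(prompt, return_dict=False)
```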

View File

@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
init_image = preprocess(init_image)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image.to(self.device)).sample(generator=generator)
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# expand init_latents for batch_size
@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
t = t.to(self.unet.dtype)
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
# compute the previous noisy sample x_t -> x_t-1
if isinstance(self.scheduler, LMSDiscreteScheduler):
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents.to(self.vae.dtype))
image = self.vae.decode(latents.to(self.vae.dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, PNDMScheduler
from . import StableDiffusionPipelineOutput
from .safety_checker import StableDiffusionSafetyChecker
@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
eta: Optional[float] = 0.0,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
init_image = preprocess_image(init_image).to(self.device)
# encode the init image into latents and scale the latents
init_latents = self.vae.encode(init_image).sample(generator=generator)
init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
init_latents = init_latent_dist.sample(generator=generator)
init_latents = 0.18215 * init_latents
# Expand init_latents for batch_size
@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
# perform guidance
if do_classifier_free_guidance:
@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# masking
init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
# scale and decode the image latents with vae
latents = 1 / 0.18215 * latents
image = self.vae.decode(latents)
image = self.vae.decode(latents).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
if output_type == "pil":
image = self.numpy_to_pil(image)
return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -1,11 +1,11 @@
#!/usr/bin/env python3
import warnings
from typing import Optional
from typing import Optional, Tuple, Union
import torch
from ...models import UNet2DModel
from ...pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...schedulers import KarrasVeScheduler
@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline):
num_inference_steps: int = 50,
generator: Optional[torch.Generator] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
**kwargs,
):
) -> Union[Tuple, ImagePipelineOutput]:
if "torch_device" in kwargs:
device = kwargs.pop("torch_device")
warnings.warn(
@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline):
# 3. Predict the noise residual given the noise magnitude `sigma_hat`
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]
model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
# 4. Evaluate dx/dt at sigma_hat
# 5. Take Euler step from sigma to sigma_prev
@ -80,20 +81,23 @@ class KarrasVePipeline(DiffusionPipeline):
if sigma_prev != 0:
# 6. Apply 2nd order correction
# The model inputs and output are adjusted by following eq. (213) in [1].
model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
step_output = self.scheduler.step_correct(
model_output,
sigma_hat,
sigma_prev,
sample_hat,
step_output["prev_sample"],
step_output.prev_sample,
step_output["derivative"],
)
sample = step_output["prev_sample"]
sample = step_output.prev_sample
sample = (sample / 2 + 0.5).clamp(0, 1)
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
image = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
sample = self.numpy_to_pil(sample)
image = self.numpy_to_pil(sample)
return {"sample": sample}
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)

View File

@ -16,13 +16,13 @@
# and https://github.com/hojonathanho/diffusion
import math
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
@ -174,7 +176,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = prev_sample + variance
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(
self,
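Scheduler `step` calls mirror the model side; a hedged sketch with `model_output`, `t`, and `sample` assumed from a denoising loop:

```python
# SchedulerOutput dataclass
prev_sample = scheduler.step(model_output, t, sample, eta=0.0).prev_sample

# plain tuple
(prev_sample,) = scheduler.step(model_output, t, sample, eta=0.0, return_dict=False)
```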

View File

@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Optional, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
sample: Union[torch.FloatTensor, np.ndarray],
predict_epsilon=True,
generator=None,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
@ -177,7 +179,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
pred_prev_sample = pred_prev_sample + variance
return {"prev_sample": pred_prev_sample}
if not return_dict:
return (pred_prev_sample,)
return SchedulerOutput(prev_sample=pred_prev_sample)
def add_noise(
self,

View File

@ -13,15 +13,34 @@
# limitations under the License.
from typing import Union
from dataclasses import dataclass
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class KarrasVeOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Derivate of predicted original image sample (x_0).
"""
prev_sample: torch.FloatTensor
derivative: torch.FloatTensor
class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
"""
Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sigma_hat: float,
sigma_prev: float,
sample_hat: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_hat + sigma_hat * model_output
derivative = (sample_hat - pred_original_sample) / sigma_hat
sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
return {"prev_sample": sample_prev, "derivative": derivative}
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def step_correct(
self,
@ -117,11 +141,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
sample_hat: Union[torch.FloatTensor, np.ndarray],
sample_prev: Union[torch.FloatTensor, np.ndarray],
derivative: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[KarrasVeOutput, Tuple]:
pred_original_sample = sample_prev + sigma_prev * model_output
derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
return {"prev_sample": sample_prev, "derivative": derivative_corr}
if not return_dict:
return (sample_prev, derivative)
return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
def add_noise(self, original_samples, noise, timesteps):
raise NotImplementedError()

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from typing import Tuple, Union
import numpy as np
import torch
@ -20,7 +20,7 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
order: int = 4,
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
sigma = self.sigmas[timestep]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
)
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timesteps):
sigmas = self.match_shape(self.sigmas[timesteps], noise)

View File

@ -15,13 +15,13 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
import math
from typing import Union
from typing import Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
else:
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
def step_prk(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def step_plms(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
):
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
times to approximate the solution.
@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
self.counter += 1
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf

View File

@ -15,13 +15,32 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
import warnings
from typing import Optional, Union
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin, SchedulerOutput
@dataclass
class SdeVeOutput(BaseOutput):
"""
Output class for the ScoreSdeVeScheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
"""
prev_sample: torch.FloatTensor
prev_sample_mean: torch.FloatTensor
class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
timestep: int,
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[SdeVeOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE.
"""
@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
# TODO is the variable diffusion the correct scaling term for the noise?
prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise # add impact of diffusion field g
return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
if not return_dict:
return (prev_sample, prev_sample_mean)
return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
def step_correct(
self,
model_output: Union[torch.FloatTensor, np.ndarray],
sample: Union[torch.FloatTensor, np.ndarray],
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
**kwargs,
):
) -> Union[SchedulerOutput, Tuple]:
"""
Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
after making the prediction for the previous timestep.
@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
prev_sample_mean = sample + step_size[:, None, None, None] * model_output
prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
return {"prev_sample": prev_sample}
if not return_dict:
return (prev_sample,)
return SchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
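`step_pred` is the one step method with two fields, hence its own `SdeVeOutput`; a hedged sketch matching the pipeline change above:

```python
output = scheduler.step_pred(model_output, t, sample, generator=generator)
sample, sample_mean = output.prev_sample, output.prev_sample_mean

# the tuple form keeps the same field order
sample, sample_mean = scheduler.step_pred(model_output, t, sample, generator=generator, return_dict=False)
```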

View File

@ -11,15 +11,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Union
import numpy as np
import torch
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
@dataclass
class SchedulerOutput(BaseOutput):
"""
Base class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class SchedulerMixin:
config_name = SCHEDULER_CONFIG_NAME
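Schedulers that need extra fields subclass `BaseOutput` the same way `KarrasVeOutput` and `SdeVeOutput` do above; a hedged sketch of such a subclass (hypothetical, not part of this diff):

```python
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class MyStepOutput(BaseOutput):
    # hypothetical scheduler output with an extra diagnostic field
    prev_sample: torch.FloatTensor
    pred_original_sample: torch.FloatTensor
```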

View File

@ -33,6 +33,7 @@ from .import_utils import (
requires_backends,
)
from .logging import get_logger
from .outputs import BaseOutput
logger = get_logger(__name__)

View File

@ -0,0 +1,139 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Generic utilities
"""
import warnings
from collections import OrderedDict
from dataclasses import fields
from typing import Any, Tuple
import numpy as np
from .import_utils import is_torch_available
def is_tensor(x):
"""
Tests if `x` is a `torch.Tensor` or `np.ndarray`.
"""
if is_torch_available():
import torch
if isinstance(x, torch.Tensor):
return True
return isinstance(x, np.ndarray)
class BaseOutput(OrderedDict):
"""
Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
python dictionary.
<Tip warning={true}>
You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
before.
</Tip>
"""
def __post_init__(self):
class_fields = fields(self)
# Safety and consistency checks
if not len(class_fields):
raise ValueError(f"{self.__class__.__name__} has no fields.")
first_field = getattr(self, class_fields[0].name)
other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
if other_fields_are_none and not is_tensor(first_field):
if isinstance(first_field, dict):
iterator = first_field.items()
first_field_iterator = True
else:
try:
iterator = iter(first_field)
first_field_iterator = True
except TypeError:
first_field_iterator = False
# if we provided an iterator as first field and the iterator is a (key, value) iterator
# set the associated fields
if first_field_iterator:
for element in iterator:
if (
not isinstance(element, (list, tuple))
or not len(element) == 2
or not isinstance(element[0], str)
):
break
setattr(self, element[0], element[1])
if element[1] is not None:
self[element[0]] = element[1]
elif first_field is not None:
self[class_fields[0].name] = first_field
else:
for field in class_fields:
v = getattr(self, field.name)
if v is not None:
self[field.name] = v
def __delitem__(self, *args, **kwargs):
raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
def setdefault(self, *args, **kwargs):
raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
def pop(self, *args, **kwargs):
raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
def update(self, *args, **kwargs):
raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
def __getitem__(self, k):
if isinstance(k, str):
inner_dict = {k: v for (k, v) in self.items()}
if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
warnings.warn(
"The keyword 'samples' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
" `'images'` instead.",
DeprecationWarning,
)
return inner_dict["images"]
return inner_dict[k]
else:
return self.to_tuple()[k]
def __setattr__(self, name, value):
if name in self.keys() and value is not None:
# Don't call self.__setitem__ to avoid recursion errors
super().__setitem__(name, value)
super().__setattr__(name, value)
def __setitem__(self, key, value):
# Will raise a KeyException if needed
super().__setitem__(key, value)
# Don't call self.__setattr__ to avoid recursion errors
super().__setattr__(key, value)
def to_tuple(self) -> Tuple[Any]:
"""
Convert self to a tuple containing all the attributes/keys that are not `None`.
"""
return tuple(self[k] for k in self.keys())
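The resulting objects behave like frozen ordered dicts with a tuple fallback; a hedged sketch of the contract this class implements:

```python
import numpy as np
from diffusers.pipeline_utils import ImagePipelineOutput

out = ImagePipelineOutput(images=np.zeros((1, 32, 32, 3)))

out.images      # attribute access
out["images"]   # string-key access
out[0]          # integer index, resolved via to_tuple()[0]
out.to_tuple()  # (images,), only the non-None fields

# mutating methods (pop, update, setdefault, __delitem__) raise by design
```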

View File

@ -15,6 +15,7 @@
import inspect
import tempfile
from typing import Dict, List, Tuple
import numpy as np
import torch
@ -39,12 +40,12 @@ class ModelTesterMixin:
with torch.no_grad():
image = model(**inputs_dict)
if isinstance(image, dict):
image = image["sample"]
image = image.sample
new_image = new_model(**inputs_dict)
if isinstance(new_image, dict):
new_image = new_image["sample"]
new_image = new_image.sample
max_diff = (image - new_image).abs().sum().item()
self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@ -57,11 +58,11 @@ class ModelTesterMixin:
with torch.no_grad():
first = model(**inputs_dict)
if isinstance(first, dict):
first = first["sample"]
first = first.sample
second = model(**inputs_dict)
if isinstance(second, dict):
second = second["sample"]
second = second.sample
out_1 = first.cpu().numpy()
out_2 = second.cpu().numpy()
@ -80,7 +81,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
@ -122,12 +123,12 @@ class ModelTesterMixin:
output_1 = model(**inputs_dict)
if isinstance(output_1, dict):
output_1 = output_1["sample"]
output_1 = output_1.sample
output_2 = new_model(**inputs_dict)
if isinstance(output_2, dict):
output_2 = output_2["sample"]
output_2 = output_2.sample
self.assertEqual(output_1.shape, output_2.shape)
@ -140,7 +141,7 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
@ -157,9 +158,47 @@ class ModelTesterMixin:
output = model(**inputs_dict)
if isinstance(output, dict):
output = output["sample"]
output = output.sample
noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
loss = torch.nn.functional.mse_loss(output, noise)
loss.backward()
ema_model.step(model)
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
outputs_dict = model(**inputs_dict)
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)

View File

@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
# time_step = torch.tensor([10])
#
# with torch.no_grad():
# output = model(noise, time_step)["sample"]
# output = model(noise, time_step).sample
#
# output_slice = output[0, -1, -3:, -3:].flatten()
# fmt: off
@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.dummy_input)["sample"]
image = model(**self.dummy_input).sample
assert image is not None, "Make sure output is not None"
@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off
@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off
@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
with torch.no_grad():
output = model(noise, time_step)["sample"]
output = model(noise, time_step).sample
output_slice = output[0, -3:, -3:, -1].flatten().cpu()
# fmt: off

View File

@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image, sample_posterior=True)
output = model(image, sample_posterior=True).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off

View File

@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
image = image.to(torch_device)
with torch.no_grad():
output = model(image)
output = model(image).sample
output_slice = output[0, -1, -3:, -3:].flatten().cpu()
# fmt: off

View File

@ -67,12 +67,12 @@ def test_progress_bar(capsys):
scheduler = DDPMScheduler(num_train_timesteps=10)
ddpm = DDPMPipeline(model, scheduler).to(torch_device)
ddpm(output_type="numpy")["sample"]
ddpm(output_type="numpy").images
captured = capsys.readouterr()
assert "10/10" in captured.err, "Progress bar has to be displayed"
ddpm.set_progress_bar_config(disable=True)
ddpm(output_type="numpy")["sample"]
ddpm(output_type="numpy").images
captured = capsys.readouterr()
assert captured.err == "", "Progress bar should be disabled"
@ -196,15 +196,20 @@ class PipelineFastTests(unittest.TestCase):
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array(
[1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_pndm_cifar10(self):
unet = self.dummy_uncond_unet
@ -213,14 +218,20 @@ class PipelineFastTests(unittest.TestCase):
pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
pndm.to(torch_device)
pndm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_ldm_text2img(self):
unet = self.dummy_cond_unet
@ -239,11 +250,23 @@ class PipelineFastTests(unittest.TestCase):
"sample"
]
generator = torch.manual_seed(0)
image_from_tuple = ldm(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="numpy",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_ddim(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@ -274,16 +297,28 @@ class PipelineFastTests(unittest.TestCase):
sd_pipe.set_progress_bar_config(disable=None)
prompt = "A painting of a squirrel eating a burger"
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_pndm(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@ -310,13 +345,25 @@ class PipelineFastTests(unittest.TestCase):
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@ -343,13 +390,25 @@ class PipelineFastTests(unittest.TestCase):
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_score_sde_ve_pipeline(self):
unet = self.dummy_uncond_unet
@ -360,14 +419,19 @@ class PipelineFastTests(unittest.TestCase):
sde_ve.set_progress_bar_config(disable=None)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
image = sde_ve(num_inference_steps=2, output_type="numpy").images
torch.manual_seed(0)
image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_ldm_uncond(self):
unet = self.dummy_uncond_unet
@ -379,13 +443,18 @@ class PipelineFastTests(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_karras_ve_pipeline(self):
unet = self.dummy_uncond_unet
@ -396,12 +465,18 @@ class PipelineFastTests(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@ -437,13 +512,26 @@ class PipelineFastTests(unittest.TestCase):
init_image=init_image,
)
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_img2img_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@ -479,14 +567,27 @@ class PipelineFastTests(unittest.TestCase):
output_type="np",
init_image=init_image,
)
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
output = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
return_dict=False,
)
image_from_tuple = output[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_inpaint(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
@ -525,13 +626,27 @@ class PipelineFastTests(unittest.TestCase):
mask_image=mask_image,
)
image = output["sample"]
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = sd_pipe(
[prompt],
generator=generator,
guidance_scale=6.0,
num_inference_steps=2,
output_type="np",
init_image=init_image,
mask_image=mask_image,
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 32, 32, 3)
expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
class PipelineTesterMixin(unittest.TestCase):
@ -565,9 +680,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
generator = generator.manual_seed(0)
new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
new_image = new_ddpm(generator=generator, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@ -586,9 +701,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@ -610,9 +725,9 @@ class PipelineTesterMixin(unittest.TestCase):
generator = torch.manual_seed(0)
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images
generator = generator.manual_seed(0)
new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
@ -625,17 +740,17 @@ class PipelineTesterMixin(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy")["sample"]
images = pipe(generator=generator, output_type="numpy").images
assert images.shape == (1, 32, 32, 3)
assert isinstance(images, np.ndarray)
images = pipe(generator=generator, output_type="pil")["sample"]
images = pipe(generator=generator, output_type="pil").images
assert isinstance(images, list)
assert len(images) == 1
assert isinstance(images[0], PIL.Image.Image)
# use PIL by default
images = pipe(generator=generator)["sample"]
images = pipe(generator=generator).images
assert isinstance(images, list)
assert isinstance(images[0], PIL.Image.Image)
@ -652,7 +767,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -672,7 +787,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddpm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddpm(generator=generator, output_type="numpy")["sample"]
image = ddpm(generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -692,7 +807,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
image = ddim(generator=generator, eta=0.0, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -711,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
pndm.to(torch_device)
pndm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pndm(generator=generator, output_type="numpy")["sample"]
image = pndm(generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -745,7 +860,7 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A painting of a squirrel eating a burger"
generator = torch.manual_seed(0)
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -768,7 +883,7 @@ class PipelineTesterMixin(unittest.TestCase):
[prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
)
image = output["sample"]
image = output.images
image_slice = image[0, -3:, -3:, -1]
@ -797,7 +912,7 @@ class PipelineTesterMixin(unittest.TestCase):
with torch.autocast("cuda"):
output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
image = output["sample"]
image = output.images
image_slice = image[0, -3:, -3:, -1]
@ -817,7 +932,7 @@ class PipelineTesterMixin(unittest.TestCase):
sde_ve.set_progress_bar_config(disable=None)
torch.manual_seed(0)
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
image = sde_ve(num_inference_steps=300, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -833,7 +948,7 @@ class PipelineTesterMixin(unittest.TestCase):
ldm.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
@ -857,10 +972,10 @@ class PipelineTesterMixin(unittest.TestCase):
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
ddpm_image = ddpm(generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
# the values aren't exactly equal, but the images look the same visually
assert np.abs(ddpm_image - ddim_image).max() < 1e-1
@ -882,7 +997,7 @@ class PipelineTesterMixin(unittest.TestCase):
ddim.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
generator = torch.manual_seed(0)
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
"sample"
]
ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
@ -903,7 +1018,7 @@ class PipelineTesterMixin(unittest.TestCase):
pipe.set_progress_bar_config(disable=None)
generator = torch.manual_seed(0)
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 256, 256, 3)
@ -974,9 +1089,8 @@ class PipelineTesterMixin(unittest.TestCase):
prompt = "A fantasy landscape, trending on artstation"
generator = torch.Generator(device=torch_device).manual_seed(0)
image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
"sample"
][0]
output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
image = output.images[0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
@ -1008,7 +1122,7 @@ class PipelineTesterMixin(unittest.TestCase):
strength=0.75,
guidance_scale=7.5,
generator=generator,
)["sample"][0]
).images[0]
expected_array = np.array(output_image)
sampled_array = np.array(image)
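
Taken together, the pipeline tests above all assert the same contract; a minimal sketch of it (the checkpoint id and tolerance here are chosen for illustration):

```python
import numpy as np
import torch
from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy").images  # was ["sample"]

generator = torch.manual_seed(0)
images_tuple = pipe(generator=generator, output_type="numpy", return_dict=False)[0]

# Dataclass and tuple outputs must carry identical image arrays.
assert np.abs(images - images_tuple).max() < 1e-5
```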

View File

@ -14,6 +14,7 @@
# limitations under the License.
import tempfile
import unittest
from typing import Dict, List, Tuple
import numpy as np
import torch
@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
torch.manual_seed(0)
new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase):
kwargs["num_inference_steps"] = num_inference_steps
torch.manual_seed(0)
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
torch.manual_seed(0)
new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
return t
def recursive_check(tuple_object, dict_object):
if isinstance(tuple_object, (List, Tuple)):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif isinstance(tuple_object, Dict):
for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
recursive_check(tuple_iterable_value, dict_iterable_value)
elif tuple_object is None:
return
else:
self.assertTrue(
torch.allclose(
set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
),
msg=(
"Tuple and dict output are not equal. Difference:"
f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
),
)
kwargs = dict(self.forward_default_kwargs)
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
recursive_check(outputs_tuple, outputs_dict)
class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,)
@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
# if t > 0:
# noise = self.dummy_sample_deter
@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
for t in scheduler.timesteps:
residual = model(sample, t)
sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
sample = scheduler.step(residual, t, sample, eta).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residuals
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
# copy over dummy past residual (must be after setting timesteps)
new_scheduler.ets = dummy_past_residuals[:]
output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler.ets = dummy_past_residuals[:]
scheduler_pt.ets = dummy_past_residuals_pt[:]
output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
scheduler.ets = dummy_past_residuals[:]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
@ -587,11 +641,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for i, t in enumerate(scheduler.prk_timesteps):
residual = model(sample, t)
sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
sample = scheduler.step_prk(residual, i, sample).prev_sample
for i, t in enumerate(scheduler.plms_timesteps):
residual = model(sample, t)
sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
sample = scheduler.step_plms(residual, i, sample).prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler correction are not identical"
@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
for _ in range(scheduler.correct_steps):
with torch.no_grad():
model_output = model(sample, sigma_t)
sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]
sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
with torch.no_grad():
model_output = model(sample, sigma_t)
output = scheduler.step_pred(model_output, t, sample, **kwargs)
sample, _ = output["prev_sample"], output["prev_sample_mean"]
sample, _ = output.prev_sample, output.prev_sample_mean
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
kwargs["num_inference_steps"] = num_inference_steps
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"]
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)
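
The scheduler-side contract mirrors the model one; a minimal sketch using `DDPMScheduler` (timestep 0 keeps the step deterministic, as in the equivalence test above):

```python
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

sample = torch.randn(1, 3, 32, 32)
residual = 0.1 * sample  # stand-in for a model's noise prediction

out = scheduler.step(residual, 0, sample)                        # output dataclass
out_tuple = scheduler.step(residual, 0, sample, return_dict=False)

# `.prev_sample` replaces the old ["prev_sample"] lookup; the tuple's
# first element must match it exactly.
assert torch.allclose(out.prev_sample, out_tuple[0])
```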

View File

@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4):
optimizer.zero_grad()
ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
loss.backward()
optimizer.step()
@ -78,7 +78,7 @@ class TrainingTests(unittest.TestCase):
for i in range(4):
optimizer.zero_grad()
ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
ddim_noise_pred = model(ddim_noisy_images, timesteps[i]).sample
loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
loss.backward()
optimizer.step()
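
Finally, the training tests reduce to this pattern; a minimal sketch with toy tensors (shapes and hyperparameters are illustrative):

```python
import torch
from diffusers import DDPMScheduler, UNet2DModel

model = UNet2DModel(sample_size=32, in_channels=3, out_channels=3)
scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

clean_images = torch.randn(4, 3, 32, 32)
noise = torch.randn(4, 3, 32, 32)
timesteps = torch.randint(0, 1000, (4,))

optimizer.zero_grad()
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
noise_pred = model(noisy_images, timesteps).sample  # was: model(...)["sample"]
loss = torch.nn.functional.mse_loss(noise_pred, noise)
loss.backward()
optimizer.step()
```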