[ModelOutputs] Replace dict outputs with Dict/Dataclass and allow returning tuples (#334)
* add outputs for models
* add for pipelines
* finish schedulers
* better naming
* adapt tests as well
* replace dict access with . access
* make schedulers work
* finish
* correct readme
* make bcp compatible
* up
* small fix
* finish
* more fixes
* more fixes
* Apply suggestions from code review
* Update src/diffusers/models/vae.py
* Adapt model outputs
* Apply more suggestions
* finish examples
* correct

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
parent daddd98b88
commit cc59b05635
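In practice this commit swaps plain dict returns for `BaseOutput` dataclasses across models, pipelines, and schedulers. A minimal before/after sketch of the three access styles (the checkpoint id is illustrative and loading details such as auth tokens are omitted):

```python
from diffusers import StableDiffusionPipeline

# illustrative checkpoint; download/auth details omitted
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")

prompt = "a photo of an astronaut riding a horse on mars"

# before this commit: a plain dict
# image = pipe(prompt)["sample"][0]

# after: a dataclass with named fields...
image = pipe(prompt).images[0]

# ...that still supports dict-style key access (BaseOutput subclasses OrderedDict)
image = pipe(prompt)["images"][0]

# ...or a plain tuple when return_dict=False is passed
image = pipe(prompt, return_dict=False)[0][0]
```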
README.md (16 changes)

@@ -80,7 +80,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 
 **Note**: If you don't want to use the token, you can also simply download the model weights
@@ -101,7 +101,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 
 If you are limited by GPU memory, you might want to consider using the model in `fp16`.
@@ -117,7 +117,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 ```
 
 Finally, if you wish to use a different scheduler, you can simply instantiate
@@ -143,7 +143,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -184,7 +184,7 @@ init_image = init_image.resize((768, 512))
 prompt = "A fantasy landscape, trending on artstation"
 
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -228,7 +228,7 @@ pipe = pipe.to(device)
 
 prompt = "a cat sitting on a bench"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
@@ -260,7 +260,7 @@ ldm = DiffusionPipeline.from_pretrained(model_id)
 
 # run pipeline in inference (sample random noise and denoise)
 prompt = "A painting of a squirrel eating a burger"
-images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6)["sample"]
+images = ldm([prompt], num_inference_steps=50, eta=0.3, guidance_scale=6).images
 
 # save images
 for idx, image in enumerate(images):
@@ -277,7 +277,7 @@ model_id = "google/ddpm-celebahq-256"
 ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference
 
 # run pipeline in inference (sample random noise and denoise)
-image = ddpm()["sample"]
+image = ddpm().images
 
 # save image
 image[0].save("ddpm_generated_image.png")
@@ -76,7 +76,7 @@ pipe = pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch
 prompt = "A <cat-toy> backpack"
 
 with autocast("cuda"):
-    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)["sample"][0]
+    image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
 
 image.save("cat-backpack.png")
 ```
@@ -498,7 +498,7 @@ def main():
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).sample().detach()
+                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
@@ -515,7 +515,7 @@ def main():
                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
                 # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states)["sample"]
+                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
                 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
                 accelerator.backward(loss)
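The change above exists because `vae.encode` now returns an `AutoencoderKLOutput` that wraps the posterior in a `latent_dist` field, instead of returning the `DiagonalGaussianDistribution` directly. A self-contained sketch of the new access patterns (default-config VAE with random weights; shapes and the 0.18215 scaling mirror the script above):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()  # randomly initialized default config, illustrative only
pixel_values = torch.randn(1, 3, 64, 64)

out = vae.encode(pixel_values)
latents = out.latent_dist.sample().detach()         # attribute access, as in the training script
latents_too = out["latent_dist"].sample().detach()  # key access still works
latents = latents * 0.18215                         # Stable Diffusion's latent scaling
```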
@@ -139,7 +139,7 @@ def main(args):
 
         with accelerator.accumulate(model):
             # Predict the noise residual
-            noise_pred = model(noisy_images, timesteps)["sample"]
+            noise_pred = model(noisy_images, timesteps).sample
             loss = F.mse_loss(noise_pred, noise)
             accelerator.backward(loss)
 
@@ -174,7 +174,7 @@ def main(args):
 
             generator = torch.manual_seed(0)
             # run pipeline in inference (sample random noise and denoise)
-            images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
+            images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy").images
 
             # denormalize the images and save to tensorboard
             images_processed = (images * 255).round().astype("uint8")
@@ -119,7 +119,7 @@ for mod in models:
     noise = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
     time_step = torch.tensor([10] * noise.shape[0])
     with torch.no_grad():
-        logits = model(noise, time_step)["sample"]
+        logits = model(noise, time_step).sample
 
     assert torch.allclose(
         logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3
@@ -19,9 +19,9 @@ import shutil
 from pathlib import Path
 from typing import Optional
 
-from diffusers import DiffusionPipeline
 from huggingface_hub import HfFolder, Repository, whoami
 
+from .pipeline_utils import DiffusionPipeline
 from .utils import is_modelcards_available, logging
 
 
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
+@dataclass
+class UNet2DOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states output. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class UNet2DModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
@@ -118,8 +131,11 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
 
     def forward(
-        self, sample: torch.FloatTensor, timestep: Union[torch.Tensor, float, int]
-    ) -> Dict[str, torch.FloatTensor]:
+        self,
+        sample: torch.FloatTensor,
+        timestep: Union[torch.Tensor, float, int],
+        return_dict: bool = True,
+    ) -> Union[UNet2DOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -181,6 +197,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             timesteps = timesteps.reshape((sample.shape[0], *([1] * len(sample.shape[1:]))))
             sample = sample / timesteps
 
-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)
 
-        return output
+        return UNet2DOutput(sample=sample)
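A sketch of calling the updated `UNet2DModel.forward` with both output styles (tiny illustrative config, random weights; real checkpoints behave the same way):

```python
import torch
from diffusers import UNet2DModel

# tiny illustrative config, not a real checkpoint
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
noise = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])

out = model(noise, timestep).sample                       # UNet2DOutput -> .sample
(out_tuple,) = model(noise, timestep, return_dict=False)  # plain tuple
assert torch.equal(out, out_tuple)
```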
@@ -1,14 +1,27 @@
-from typing import Dict, Optional, Tuple, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
 
 
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+    """
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
+    """
+
+    sample: torch.FloatTensor
+
+
 class UNet2DConditionModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
@@ -125,7 +138,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-    ) -> Dict[str, torch.FloatTensor]:
+        return_dict: bool = True,
+    ) -> Union[UNet2DConditionOutput, Tuple]:
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -183,6 +197,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
 
-        output = {"sample": sample}
+        if not return_dict:
+            return (sample,)
 
-        return output
+        return UNet2DConditionOutput(sample=sample)
@@ -1,4 +1,5 @@
-from typing import Optional, Tuple
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -6,9 +7,50 @@ import torch.nn as nn
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
+from ..utils import BaseOutput
 from .unet_blocks import UNetMidBlock2D, get_down_block, get_up_block
 
 
+@dataclass
+class DecoderOutput(BaseOutput):
+    """
+    Output of decoding method.
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Decoded output sample of the model. Output of the last layer of the model.
+    """
+
+    sample: torch.FloatTensor
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+            Encoded output sample of the model. Output of the last layer of the model.
+    """
+
+    latents: torch.FloatTensor
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+    """
+    Output of AutoencoderKL encoding method.
+
+    Args:
+        latent_dist (`DiagonalGaussianDistribution`):
+            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+    """
+
+    latent_dist: "DiagonalGaussianDistribution"
+
+
 class Encoder(nn.Module):
     def __init__(
         self,
@@ -369,12 +411,18 @@ class VQModel(ModelMixin, ConfigMixin):
             act_fn=act_fn,
         )
 
-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
         h = self.encoder(x)
         h = self.quant_conv(h)
-        return h
 
-    def decode(self, h, force_not_quantize=False):
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    def decode(
+        self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         # also go through quantization layer
         if not force_not_quantize:
             quant, emb_loss, info = self.quantize(h)
@@ -382,13 +430,21 @@ class VQModel(ModelMixin, ConfigMixin):
             quant = h
         quant = self.post_quant_conv(quant)
         dec = self.decoder(quant)
-        return dec
 
-    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         x = sample
-        h = self.encode(x)
-        dec = self.decode(h)
-        return dec
+        h = self.encode(x).latents
+        dec = self.decode(h).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
 
 
 class AutoencoderKL(ModelMixin, ConfigMixin):
@@ -431,23 +487,37 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
         self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
         self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
 
-    def encode(self, x):
+    def encode(self, x, return_dict: bool = True):
         h = self.encoder(x)
         moments = self.quant_conv(h)
         posterior = DiagonalGaussianDistribution(moments)
-        return posterior
 
-    def decode(self, z):
+        if not return_dict:
+            return (posterior,)
+
+        return AutoencoderKLOutput(latent_dist=posterior)
+
+    def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
         z = self.post_quant_conv(z)
         dec = self.decoder(z)
-        return dec
 
-    def forward(self, sample: torch.FloatTensor, sample_posterior: bool = False) -> torch.FloatTensor:
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
+
+    def forward(
+        self, sample: torch.FloatTensor, sample_posterior: bool = False, return_dict: bool = True
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         x = sample
-        posterior = self.encode(x)
+        posterior = self.encode(x).latent_dist
         if sample_posterior:
             z = posterior.sample()
        else:
             z = posterior.mode()
-        dec = self.decode(z)
-        return dec
+        dec = self.decode(z).sample
+
+        if not return_dict:
+            return (dec,)
+
+        return DecoderOutput(sample=dec)
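The `VQModel` counterpart follows the same pattern: `encode` returns a `VQEncoderOutput` with a `latents` field and `decode` returns a `DecoderOutput` with a `sample` field. A self-contained sketch (default config, random weights, purely illustrative):

```python
import torch
from diffusers import VQModel

vq = VQModel()  # randomly initialized default config, illustrative only
x = torch.randn(1, 3, 64, 64)

h = vq.encode(x).latents     # VQEncoderOutput -> .latents
recon = vq.decode(h).sample  # DecoderOutput -> .sample

# tuple variants via return_dict=False
(h,) = vq.encode(x, return_dict=False)
(recon,) = vq.decode(h, return_dict=False)
```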
@@ -17,16 +17,19 @@
 import importlib
 import inspect
 import os
-from typing import Optional, Union
+from dataclasses import dataclass
+from typing import List, Optional, Union
 
+import numpy as np
 import torch
 
+import PIL
 from huggingface_hub import snapshot_download
+from PIL import Image
 from tqdm.auto import tqdm
 
 from .configuration_utils import ConfigMixin
-from .utils import DIFFUSERS_CACHE, logging
+from .utils import DIFFUSERS_CACHE, BaseOutput, logging
 
 
 INDEX_FILE = "diffusion_pytorch_model.bin"
@@ -54,6 +57,20 @@ for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
 
 
+@dataclass
+class ImagePipelineOutput(BaseOutput):
+    """
+    Output class for image pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+
+
 class DiffusionPipeline(ConfigMixin):
 
     config_name = "model_index.json"
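Since `ImagePipelineOutput` is just a `BaseOutput` dataclass, third-party pipelines can declare their own typed outputs the same way. A minimal sketch; `MyPipelineOutput` and its extra field are hypothetical:

```python
from dataclasses import dataclass
from typing import List, Optional, Union

import numpy as np
import PIL
from PIL import Image  # ensures PIL.Image submodule is loaded

from diffusers.utils import BaseOutput


@dataclass
class MyPipelineOutput(BaseOutput):
    # hypothetical output: images plus an optional timing field
    images: Union[List[PIL.Image.Image], np.ndarray]
    elapsed_seconds: Optional[float] = None


out = MyPipelineOutput(images=np.zeros((1, 8, 8, 3)))
assert out.images is out["images"]   # attribute and key access agree
assert "elapsed_seconds" not in out  # None fields are skipped by BaseOutput
```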
@@ -94,7 +94,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).images[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -130,7 +130,7 @@ init_image = init_image.resize((768, 512))
 prompt = "A fantasy landscape, trending on artstation"
 
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5).images
 
 images[0].save("fantasy_landscape.png")
 ```
@@ -174,7 +174,7 @@ pipe = StableDiffusionInpaintPipeline.from_pretrained(
 
 prompt = "a cat sitting on a bench"
 with autocast("cuda"):
-    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75)["sample"]
+    images = pipe(prompt=prompt, init_image=init_image, mask_image=mask_image, strength=0.75).images
 
 images[0].save("cat_on_bench.png")
 ```
@@ -15,10 +15,11 @@
 
 
 import warnings
+from typing import Tuple, Union
 
 import torch
 
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class DDIMPipeline(DiffusionPipeline):
@@ -28,7 +29,16 @@ class DDIMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, eta=0.0, num_inference_steps=50, output_type="pil", **kwargs):
+    def __call__(
+        self,
+        batch_size=1,
+        generator=None,
+        eta=0.0,
+        num_inference_steps=50,
+        output_type="pil",
+        return_dict: bool = True,
+        **kwargs,
+    ) -> Union[ImagePipelineOutput, Tuple]:
 
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
@@ -56,15 +66,18 @@ class DDIMPipeline(DiffusionPipeline):
 
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
-            model_output = self.unet(image, t)["sample"]
+            model_output = self.unet(image, t).sample
 
             # 2. predict previous mean of image x_t-1 and add variance depending on eta
             # do x_t -> x_t-1
-            image = self.scheduler.step(model_output, t, image, eta)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, eta).prev_sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
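Usage of the updated `__call__` (the checkpoint id is the one from the README above and is illustrative; per that README, `DDIMPipeline` can be swapped in for `DDPMPipeline` on the same weights):

```python
from diffusers import DDIMPipeline

pipe = DDIMPipeline.from_pretrained("google/ddpm-celebahq-256")  # illustrative checkpoint

images = pipe(batch_size=1, num_inference_steps=50).images  # ImagePipelineOutput
(images,) = pipe(batch_size=1, num_inference_steps=50, return_dict=False)  # plain tuple
images[0].save("sample.png")
```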
@@ -15,10 +15,11 @@
 
 
 import warnings
+from typing import Tuple, Union
 
 import torch
 
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class DDPMPipeline(DiffusionPipeline):
@@ -28,7 +29,9 @@ class DDPMPipeline(DiffusionPipeline):
         self.register_modules(unet=unet, scheduler=scheduler)
 
     @torch.no_grad()
-    def __call__(self, batch_size=1, generator=None, output_type="pil", **kwargs):
+    def __call__(
+        self, batch_size=1, generator=None, output_type="pil", return_dict: bool = True, **kwargs
+    ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(
@@ -53,14 +56,17 @@ class DDPMPipeline(DiffusionPipeline):
 
         for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
-            model_output = self.unet(image, t)["sample"]
+            model_output = self.unet(image, t).sample
 
             # 2. compute previous image: x_t -> t_t-1
-            image = self.scheduler.step(model_output, t, image, generator=generator)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
@@ -12,7 +12,7 @@ from transformers.modeling_outputs import BaseModelOutput
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import logging
 
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 class LDMTextToImagePipeline(DiffusionPipeline):
@@ -32,8 +32,9 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple, ImagePipelineOutput]:
         # eta corresponds to η in paper and should be between [0, 1]
 
         if "torch_device" in kwargs:
@@ -95,25 +96,28 @@ class LDMTextToImagePipeline(DiffusionPipeline):
                 context = torch.cat([uncond_embeddings, text_embeddings])
 
             # predict the noise residual
-            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context)["sample"]
+            noise_pred = self.unet(latents_input, t, encoder_hidden_states=context).sample
             # perform guidance
             if guidance_scale != 1.0:
                 noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vqvae.decode(latents)
+        image = self.vqvae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
 
 
 ################################################################################
@@ -525,7 +529,7 @@ class LDMBertEncoder(LDMBertPreTrainedModel):
                 Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                 for more detail.
             return_dict (`bool`, *optional*):
-                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+                Whether or not to return a [`~utils.BaseModelOutput`] instead of a plain tuple.
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1,11 +1,11 @@
 import inspect
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
 from ...models import UNet2DModel, VQModel
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler
 
 
@@ -28,8 +28,9 @@ class LDMPipeline(DiffusionPipeline):
         eta: float = 0.0,
         num_inference_steps: int = 50,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple, ImagePipelineOutput]:
         # eta corresponds to η in paper and should be between [0, 1]
 
         if "torch_device" in kwargs:
@@ -61,16 +62,19 @@ class LDMPipeline(DiffusionPipeline):
 
         for t in self.progress_bar(self.scheduler.timesteps):
             # predict the noise residual
-            noise_prediction = self.unet(latents, t)["sample"]
+            noise_prediction = self.unet(latents, t).sample
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_prediction, t, latents, **extra_kwargs).prev_sample
 
         # decode the image latents with the VAE
-        image = self.vqvae.decode(latents)
+        image = self.vqvae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
@@ -15,12 +15,12 @@
 
 
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import PNDMScheduler
 
 
@@ -40,8 +40,9 @@ class PNDMPipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[ImagePipelineOutput, Tuple]:
         # For more information on the sampling method you can take a look at Algorithm 2 of
         # the official paper: https://arxiv.org/pdf/2202.09778.pdf
 
@@ -66,13 +67,16 @@ class PNDMPipeline(DiffusionPipeline):
 
         self.scheduler.set_timesteps(num_inference_steps)
         for t in self.progress_bar(self.scheduler.timesteps):
-            model_output = self.unet(image, t)["sample"]
+            model_output = self.unet(image, t).sample
 
-            image = self.scheduler.step(model_output, t, image)["prev_sample"]
+            image = self.scheduler.step(model_output, t, image).prev_sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
@@ -1,12 +1,11 @@
 #!/usr/bin/env python3
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
-from diffusers import DiffusionPipeline
-
 from ...models import UNet2DModel
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import ScoreSdeVeScheduler
 
 
@@ -26,8 +25,9 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         num_inference_steps: int = 2000,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[ImagePipelineOutput, Tuple]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(
@@ -56,18 +56,21 @@ class ScoreSdeVePipeline(DiffusionPipeline):
 
             # correction step
             for _ in range(self.scheduler.correct_steps):
-                model_output = self.unet(sample, sigma_t)["sample"]
-                sample = self.scheduler.step_correct(model_output, sample, generator=generator)["prev_sample"]
+                model_output = self.unet(sample, sigma_t).sample
+                sample = self.scheduler.step_correct(model_output, sample, generator=generator).prev_sample
 
             # prediction step
-            model_output = model(sample, sigma_t)["sample"]
+            model_output = model(sample, sigma_t).sample
             output = self.scheduler.step_pred(model_output, t, sample, generator=generator)
 
-            sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
+            sample, sample_mean = output.prev_sample, output.prev_sample_mean
 
         sample = sample_mean.clamp(0, 1)
         sample = sample.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
             sample = self.numpy_to_pil(sample)
 
-        return {"sample": sample}
+        if not return_dict:
+            return (sample,)
+
+        return ImagePipelineOutput(images=sample)
@@ -67,7 +67,7 @@ pipe = pipe.to("cuda")
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).sample[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -89,7 +89,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).sample[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -115,7 +115,7 @@ pipe = StableDiffusionPipeline.from_pretrained(
 
 prompt = "a photo of an astronaut riding a horse on mars"
 with autocast("cuda"):
-    image = pipe(prompt)["sample"][0]
+    image = pipe(prompt).sample[0]
 
 image.save("astronaut_rides_horse.png")
 ```
@@ -1,5 +1,31 @@
 # flake8: noqa
-from ...utils import is_transformers_available
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+
+import PIL
+from PIL import Image
+
+from ...utils import BaseOutput, is_transformers_available
+
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+    """
+    Output class for Stable Diffusion pipelines.
+
+    Args:
+        images (`List[PIL.Image.Image]` or `np.ndarray`)
+            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
+            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
+        nsfw_content_detected (`List[bool]`)
+            List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content.
+    """
+
+    images: Union[List[PIL.Image.Image], np.ndarray]
+    nsfw_content_detected: List[bool]
 
 
 if is_transformers_available():
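The fields of `StableDiffusionPipelineOutput` line up one-to-one with the tuple returned when `return_dict=False`. A small sketch constructing one directly just to show the fields (pipelines normally build it for you):

```python
import PIL.Image

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput

img = PIL.Image.new("RGB", (8, 8))
out = StableDiffusionPipelineOutput(images=[img], nsfw_content_detected=[False])

# keep only images that did not trip the safety checker
for image, flagged in zip(out.images, out.nsfw_content_detected):
    if not flagged:
        image.save("ok.png")
```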
@@ -9,6 +9,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
@@ -47,6 +48,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
     ):
         if "torch_device" in kwargs:
@@ -141,7 +143,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -150,13 +152,13 @@ class StableDiffusionPipeline(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -168,4 +170,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -10,6 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
@@ -57,6 +58,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
     ):
 
         if isinstance(prompt, str):
@@ -83,7 +85,8 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         init_image = preprocess(init_image)
 
         # encode the init image into latents and scale the latents
-        init_latents = self.vae.encode(init_image.to(self.device)).sample(generator=generator)
+        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
         init_latents = 0.18215 * init_latents
 
         # expand init_latents for batch_size
@@ -158,7 +161,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
             t = t.to(self.unet.dtype)
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -167,13 +170,13 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
 
             # compute the previous noisy sample x_t -> x_t-1
             if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
             else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents.to(self.vae.dtype))
+        image = self.vae.decode(latents.to(self.vae.dtype)).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -185,4 +188,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -11,6 +11,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, PNDMScheduler
+from . import StableDiffusionPipelineOutput
 from .safety_checker import StableDiffusionSafetyChecker
 
 
@@ -72,6 +73,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
     ):
 
         if isinstance(prompt, str):
@@ -98,7 +100,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         init_image = preprocess_image(init_image).to(self.device)
 
         # encode the init image into latents and scale the latents
-        init_latents = self.vae.encode(init_image).sample(generator=generator)
+        init_latent_dist = self.vae.encode(init_image.to(self.device)).latent_dist
+        init_latents = init_latent_dist.sample(generator=generator)
+
         init_latents = 0.18215 * init_latents
 
         # Expand init_latents for batch_size
@@ -166,7 +170,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
             # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
+            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
             # perform guidance
             if do_classifier_free_guidance:
@@ -174,7 +178,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
             noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
             # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs)["prev_sample"]
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
             # masking
             init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, t)
@@ -182,7 +186,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
 
         # scale and decode the image latents with vae
         latents = 1 / 0.18215 * latents
-        image = self.vae.decode(latents)
+        image = self.vae.decode(latents).sample
 
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
@@ -194,4 +198,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         if output_type == "pil":
             image = self.numpy_to_pil(image)
 
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -1,11 +1,11 @@
 #!/usr/bin/env python3
 import warnings
-from typing import Optional
+from typing import Optional, Tuple, Union
 
 import torch
 
 from ...models import UNet2DModel
-from ...pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import KarrasVeScheduler
 
 
@@ -35,8 +35,9 @@ class KarrasVePipeline(DiffusionPipeline):
         num_inference_steps: int = 50,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[Tuple, ImagePipelineOutput]:
         if "torch_device" in kwargs:
             device = kwargs.pop("torch_device")
             warnings.warn(
@@ -71,7 +72,7 @@ class KarrasVePipeline(DiffusionPipeline):
 
             # 3. Predict the noise residual given the noise magnitude `sigma_hat`
             # The model inputs and output are adjusted by following eq. (213) in [1].
-            model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2)["sample"]
+            model_output = (sigma_hat / 2) * model((sample_hat + 1) / 2, sigma_hat / 2).sample
 
             # 4. Evaluate dx/dt at sigma_hat
             # 5. Take Euler step from sigma to sigma_prev
@@ -80,20 +81,23 @@ class KarrasVePipeline(DiffusionPipeline):
             if sigma_prev != 0:
                 # 6. Apply 2nd order correction
                 # The model inputs and output are adjusted by following eq. (213) in [1].
-                model_output = (sigma_prev / 2) * model((step_output["prev_sample"] + 1) / 2, sigma_prev / 2)["sample"]
+                model_output = (sigma_prev / 2) * model((step_output.prev_sample + 1) / 2, sigma_prev / 2).sample
                 step_output = self.scheduler.step_correct(
                     model_output,
                     sigma_hat,
                     sigma_prev,
                     sample_hat,
-                    step_output["prev_sample"],
+                    step_output.prev_sample,
                     step_output["derivative"],
                 )
-            sample = step_output["prev_sample"]
+            sample = step_output.prev_sample
 
         sample = (sample / 2 + 0.5).clamp(0, 1)
-        sample = sample.cpu().permute(0, 2, 3, 1).numpy()
+        image = sample.cpu().permute(0, 2, 3, 1).numpy()
         if output_type == "pil":
-            sample = self.numpy_to_pil(sample)
+            image = self.numpy_to_pil(sample)
 
-        return {"sample": sample}
+        if not return_dict:
+            return (image,)
+
+        return ImagePipelineOutput(images=image)
@@ -16,13 +16,13 @@
 # and https://github.com/hojonathanho/diffusion
 
 import math
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -116,7 +116,9 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         eta: float = 0.0,
         use_clipped_model_output: bool = False,
         generator=None,
-    ):
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
 
         if self.num_inference_steps is None:
             raise ValueError(
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
@@ -174,7 +176,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
 
         prev_sample = prev_sample + variance
 
-        return {"prev_sample": prev_sample}
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
 
     def add_noise(
         self,
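With `SchedulerOutput`, a denoising loop reads the same way across schedulers. A self-contained sketch with a tiny randomly initialized model (illustrative config, so outputs are noise):

```python
import torch
from diffusers import DDIMScheduler, UNet2DModel

# tiny illustrative model, not a real checkpoint
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=1,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
scheduler = DDIMScheduler()
scheduler.set_timesteps(5)

sample = torch.randn(1, 3, 32, 32)
for t in scheduler.timesteps:
    with torch.no_grad():
        noise_pred = model(sample, t).sample
    sample = scheduler.step(noise_pred, t, sample).prev_sample
    # tuple form: sample = scheduler.step(noise_pred, t, sample, return_dict=False)[0]
```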
@@ -15,13 +15,13 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -135,7 +135,9 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         sample: Union[torch.FloatTensor, np.ndarray],
         predict_epsilon=True,
         generator=None,
-    ):
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
 
         t = timestep
 
         if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
@@ -177,7 +179,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
 
         pred_prev_sample = pred_prev_sample + variance
 
-        return {"prev_sample": pred_prev_sample}
+        if not return_dict:
+            return (pred_prev_sample,)
+
+        return SchedulerOutput(prev_sample=pred_prev_sample)
 
     def add_noise(
         self,
@@ -13,15 +13,34 @@
 # limitations under the License.
 
 
-from typing import Union
+from dataclasses import dataclass
+from typing import Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput
 from .scheduling_utils import SchedulerMixin
 
 
+@dataclass
+class KarrasVeOutput(BaseOutput):
+    """
+    Output class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Derivate of predicted original image sample (x_0).
+    """
+
+    prev_sample: torch.FloatTensor
+    derivative: torch.FloatTensor
+
+
 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
@@ -102,12 +121,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         sigma_hat: float,
         sigma_prev: float,
         sample_hat: Union[torch.FloatTensor, np.ndarray],
-    ):
+        return_dict: bool = True,
+    ) -> Union[KarrasVeOutput, Tuple]:
 
         pred_original_sample = sample_hat + sigma_hat * model_output
         derivative = (sample_hat - pred_original_sample) / sigma_hat
         sample_prev = sample_hat + (sigma_prev - sigma_hat) * derivative
 
-        return {"prev_sample": sample_prev, "derivative": derivative}
+        if not return_dict:
+            return (sample_prev, derivative)
+
+        return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
 
     def step_correct(
         self,
@@ -117,11 +141,17 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         sample_hat: Union[torch.FloatTensor, np.ndarray],
         sample_prev: Union[torch.FloatTensor, np.ndarray],
         derivative: Union[torch.FloatTensor, np.ndarray],
-    ):
+        return_dict: bool = True,
+    ) -> Union[KarrasVeOutput, Tuple]:
 
         pred_original_sample = sample_prev + sigma_prev * model_output
         derivative_corr = (sample_prev - pred_original_sample) / sigma_prev
         sample_prev = sample_hat + (sigma_prev - sigma_hat) * (0.5 * derivative + 0.5 * derivative_corr)
-        return {"prev_sample": sample_prev, "derivative": derivative_corr}
+
+        if not return_dict:
+            return (sample_prev, derivative)
+
+        return KarrasVeOutput(prev_sample=sample_prev, derivative=derivative)
 
     def add_noise(self, original_samples, noise, timesteps):
         raise NotImplementedError()
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+from typing import Tuple, Union
 
 import numpy as np
 import torch
@@ -20,7 +20,7 @@ import torch
 from scipy import integrate
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
 class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -100,7 +100,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
         order: int = 4,
-    ):
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
         sigma = self.sigmas[timestep]
 
         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
@@ -121,7 +122,10 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             coeff * derivative for coeff, derivative in zip(lms_coeffs, reversed(self.derivatives))
         )
 
-        return {"prev_sample": prev_sample}
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
 
     def add_noise(self, original_samples, noise, timesteps):
         sigmas = self.match_shape(self.sigmas[timesteps], noise)
@@ -15,13 +15,13 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
-from typing import Union
+from typing import Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -133,18 +133,21 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-    ):
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
 
         if self.counter < len(self.prk_timesteps) and not self.config.skip_prk_steps:
-            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample)
+            return self.step_prk(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
         else:
-            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample)
+            return self.step_plms(model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict)
 
     def step_prk(
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-    ):
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
        solution to the differential equation.
@@ -176,14 +179,18 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         self.counter += 1
 
-        return {"prev_sample": prev_sample}
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
 
     def step_plms(
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
-    ):
+        return_dict: bool = True,
+    ) -> Union[SchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
         times to approximate the solution.
@@ -226,7 +233,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         self.counter += 1
 
-        return {"prev_sample": prev_sample}
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
 
     def _get_prev_sample(self, sample, timestep, timestep_prev, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
@@ -15,13 +15,32 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
 import warnings
-from typing import Optional, Union
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput
+from .scheduling_utils import SchedulerMixin, SchedulerOutput
 
 
+@dataclass
+class SdeVeOutput(BaseOutput):
+    """
+    Output class for the ScoreSdeVeScheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+        prev_sample_mean (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Mean averaged `prev_sample`. Same as `prev_sample`, only mean-averaged over previous timesteps.
+    """
+
+    prev_sample: torch.FloatTensor
+    prev_sample_mean: torch.FloatTensor
+
+
 class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
@@ -117,8 +136,9 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         timestep: int,
         sample: Union[torch.FloatTensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[SdeVeOutput, Tuple]:
         """
         Predict the sample at the previous timestep by reversing the SDE.
         """
@@ -150,15 +170,19 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         # TODO is the variable diffusion the correct scaling term for the noise?
         prev_sample = prev_sample_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
 
-        return {"prev_sample": prev_sample, "prev_sample_mean": prev_sample_mean}
+        if not return_dict:
+            return (prev_sample, prev_sample_mean)
+
+        return SdeVeOutput(prev_sample=prev_sample, prev_sample_mean=prev_sample_mean)
 
     def step_correct(
         self,
         model_output: Union[torch.FloatTensor, np.ndarray],
         sample: Union[torch.FloatTensor, np.ndarray],
         generator: Optional[torch.Generator] = None,
+        return_dict: bool = True,
         **kwargs,
-    ):
+    ) -> Union[SchedulerOutput, Tuple]:
         """
         Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
         after making the prediction for the previous timestep.
@@ -186,7 +210,10 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
         prev_sample_mean = sample + step_size[:, None, None, None] * model_output
         prev_sample = prev_sample_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
 
-        return {"prev_sample": prev_sample}
+        if not return_dict:
+            return (prev_sample,)
+
+        return SchedulerOutput(prev_sample=prev_sample)
 
     def __len__(self):
         return self.config.num_train_timesteps
@@ -11,15 +11,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from dataclasses import dataclass
 from typing import Union
 
 import numpy as np
 import torch
 
+from ..utils import BaseOutput
+
 
 SCHEDULER_CONFIG_NAME = "scheduler_config.json"
 
 
+@dataclass
+class SchedulerOutput(BaseOutput):
+    """
+    Base class for the scheduler's step function output.
+
+    Args:
+        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+            denoising loop.
+    """
+
+    prev_sample: torch.FloatTensor
+
+
 class SchedulerMixin:
 
     config_name = SCHEDULER_CONFIG_NAME
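`SchedulerOutput` is the shared single-field contract for every scheduler's `step`. A short sketch of both access styles, using DDPM only as one concrete instance (the seeding mirrors the tests later in this diff, since `step` may draw noise internally):

```python
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
sample = torch.randn(1, 3, 32, 32)
residual = 0.1 * sample  # stand-in for a model's noise prediction

torch.manual_seed(0)
prev = scheduler.step(residual, 999, sample).prev_sample  # named field

torch.manual_seed(0)
prev_tuple = scheduler.step(residual, 999, sample, return_dict=False)[0]  # tuple entry

assert torch.allclose(prev, prev_tuple)
```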
@@ -33,6 +33,7 @@ from .import_utils import (
     requires_backends,
 )
 from .logging import get_logger
+from .outputs import BaseOutput
 
 
 logger = get_logger(__name__)
@@ -0,0 +1,139 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Generic utilities
+"""
+
+import warnings
+from collections import OrderedDict
+from dataclasses import fields
+from typing import Any, Tuple
+
+import numpy as np
+
+from .import_utils import is_torch_available
+
+
+def is_tensor(x):
+    """
+    Tests if `x` is a `torch.Tensor` or `np.ndarray`.
+    """
+    if is_torch_available():
+        import torch
+
+        if isinstance(x, torch.Tensor):
+            return True
+
+    return isinstance(x, np.ndarray)
+
+
+class BaseOutput(OrderedDict):
+    """
+    Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like
+    a tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+    python dictionary.
+
+    <Tip warning={true}>
+
+    You can't unpack a `BaseOutput` directly. Use the [`~utils.BaseOutput.to_tuple`] method to convert it to a tuple
+    first.
+
+    </Tip>
+    """
+
+    def __post_init__(self):
+        class_fields = fields(self)
+
+        # Safety and consistency checks
+        if not len(class_fields):
+            raise ValueError(f"{self.__class__.__name__} has no fields.")
+
+        first_field = getattr(self, class_fields[0].name)
+        other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+        if other_fields_are_none and not is_tensor(first_field):
+            if isinstance(first_field, dict):
+                iterator = first_field.items()
+                first_field_iterator = True
+            else:
+                try:
+                    iterator = iter(first_field)
+                    first_field_iterator = True
+                except TypeError:
+                    first_field_iterator = False
+
+            # if we provided an iterator as first field and the iterator is a (key, value) iterator
+            # set the associated fields
+            if first_field_iterator:
+                for element in iterator:
+                    if (
+                        not isinstance(element, (list, tuple))
+                        or not len(element) == 2
+                        or not isinstance(element[0], str)
+                    ):
+                        break
+                    setattr(self, element[0], element[1])
+                    if element[1] is not None:
+                        self[element[0]] = element[1]
+            elif first_field is not None:
+                self[class_fields[0].name] = first_field
+        else:
+            for field in class_fields:
+                v = getattr(self, field.name)
+                if v is not None:
+                    self[field.name] = v
+
+    def __delitem__(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+    def setdefault(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+    def pop(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+    def update(self, *args, **kwargs):
+        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+    def __getitem__(self, k):
+        if isinstance(k, str):
+            inner_dict = {k: v for (k, v) in self.items()}
+            if self.__class__.__name__ in ["StableDiffusionPipelineOutput", "ImagePipelineOutput"] and k == "sample":
+                warnings.warn(
+                    "The keyword 'sample' is deprecated and will be removed in version 0.4.0. Please use `.images` or"
+                    " `'images'` instead.",
+                    DeprecationWarning,
+                )
+                return inner_dict["images"]
+            return inner_dict[k]
+        else:
+            return self.to_tuple()[k]
+
+    def __setattr__(self, name, value):
+        if name in self.keys() and value is not None:
+            # Don't call self.__setitem__ to avoid recursion errors
+            super().__setitem__(name, value)
+        super().__setattr__(name, value)
+
+    def __setitem__(self, key, value):
+        # Will raise a KeyError if needed
+        super().__setitem__(key, value)
+        # Don't call self.__setattr__ to avoid recursion errors
+        super().__setattr__(key, value)
+
+    def to_tuple(self) -> Tuple[Any]:
+        """
+        Convert self to a tuple containing all the attributes/keys that are not `None`.
+        """
+        return tuple(self[k] for k in self.keys())
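A compact sketch of the `BaseOutput` semantics defined above, with a hypothetical subclass (`ToyOutput` exists only for this illustration): `None` fields are dropped from the dict view, string keys and integer indices both work, and `to_tuple()` gives the unpackable form.

```python
from dataclasses import dataclass

import numpy as np

from diffusers.utils import BaseOutput


@dataclass
class ToyOutput(BaseOutput):  # hypothetical, for illustration only
    images: np.ndarray
    extras: dict = None


out = ToyOutput(images=np.zeros((1, 8, 8, 3)))

out.images     # attribute access
out["images"]  # dict-style access
out[0]         # integer indexing via to_tuple()

assert out.to_tuple()[0] is out.images
assert len(out.to_tuple()) == 1  # `extras` is None, so it is skipped
```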
@@ -15,6 +15,7 @@
 
 import inspect
 import tempfile
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
@@ -39,12 +40,12 @@ class ModelTesterMixin:
         with torch.no_grad():
             image = model(**inputs_dict)
             if isinstance(image, dict):
-                image = image["sample"]
+                image = image.sample
 
             new_image = new_model(**inputs_dict)
 
             if isinstance(new_image, dict):
-                new_image = new_image["sample"]
+                new_image = new_image.sample
 
         max_diff = (image - new_image).abs().sum().item()
         self.assertLessEqual(max_diff, 5e-5, "Models give different forward passes")
@@ -57,11 +58,11 @@ class ModelTesterMixin:
         with torch.no_grad():
             first = model(**inputs_dict)
             if isinstance(first, dict):
-                first = first["sample"]
+                first = first.sample
 
             second = model(**inputs_dict)
             if isinstance(second, dict):
-                second = second["sample"]
+                second = second.sample
 
         out_1 = first.cpu().numpy()
         out_2 = second.cpu().numpy()
@@ -80,7 +81,7 @@ class ModelTesterMixin:
         output = model(**inputs_dict)
 
         if isinstance(output, dict):
-            output = output["sample"]
+            output = output.sample
 
         self.assertIsNotNone(output)
         expected_shape = inputs_dict["sample"].shape
@@ -122,12 +123,12 @@ class ModelTesterMixin:
         output_1 = model(**inputs_dict)
 
         if isinstance(output_1, dict):
-            output_1 = output_1["sample"]
+            output_1 = output_1.sample
 
         output_2 = new_model(**inputs_dict)
 
         if isinstance(output_2, dict):
-            output_2 = output_2["sample"]
+            output_2 = output_2.sample
 
         self.assertEqual(output_1.shape, output_2.shape)
 
@@ -140,7 +141,7 @@ class ModelTesterMixin:
         output = model(**inputs_dict)
 
         if isinstance(output, dict):
-            output = output["sample"]
+            output = output.sample
 
         noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
@@ -157,9 +158,47 @@ class ModelTesterMixin:
         output = model(**inputs_dict)
 
         if isinstance(output, dict):
-            output = output["sample"]
+            output = output.sample
 
         noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
         ema_model.step(model)
 
+    def test_scheduler_outputs_equivalence(self):
+        def set_nan_tensor_to_zero(t):
+            t[t != t] = 0
+            return t
+
+        def recursive_check(tuple_object, dict_object):
+            if isinstance(tuple_object, (List, Tuple)):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif isinstance(tuple_object, Dict):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif tuple_object is None:
+                return
+            else:
+                self.assertTrue(
+                    torch.allclose(
+                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+                    ),
+                    msg=(
+                        "Tuple and dict output are not equal. Difference:"
+                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+                    ),
+                )
+
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        outputs_dict = model(**inputs_dict)
+        outputs_tuple = model(**inputs_dict, return_dict=False)
+
+        recursive_check(outputs_tuple, outputs_dict)
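The model-side access pattern these tests now pin down, as a self-contained sketch (the `ModelOutput` stub below stands in for the real per-model output classes; it is not part of this diff):

```python
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class ModelOutput(BaseOutput):  # stand-in for the real model output classes
    sample: torch.FloatTensor


def model(x):  # stand-in forward pass
    return ModelOutput(sample=x * 0.5)


x = torch.randn(1, 3, 32, 32)
out = model(x)

assert isinstance(out, dict)                   # BaseOutput subclasses OrderedDict,
assert torch.equal(out.sample, out["sample"])  # so the old dict access keeps working
assert torch.equal(out.sample, out[0])         # and integer indexing works too
```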
@@ -77,7 +77,7 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
     #     time_step = torch.tensor([10])
     #
     #     with torch.no_grad():
-    #         output = model(noise, time_step)["sample"]
+    #         output = model(noise, time_step).sample
     #
     #     output_slice = output[0, -1, -3:, -3:].flatten()
     #     # fmt: off
@@ -129,7 +129,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertEqual(len(loading_info["missing_keys"]), 0)
 
         model.to(torch_device)
-        image = model(**self.dummy_input)["sample"]
+        image = model(**self.dummy_input).sample
 
         assert image is not None, "Make sure output is not None"
 
@@ -147,7 +147,7 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         time_step = torch.tensor([10] * noise.shape[0]).to(torch_device)
 
         with torch.no_grad():
-            output = model(noise, time_step)["sample"]
+            output = model(noise, time_step).sample
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
@@ -258,7 +258,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
 
         with torch.no_grad():
-            output = model(noise, time_step)["sample"]
+            output = model(noise, time_step).sample
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
@@ -283,7 +283,7 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
         time_step = torch.tensor(batch_size * [1e-4]).to(torch_device)
 
         with torch.no_grad():
-            output = model(noise, time_step)["sample"]
+            output = model(noise, time_step).sample
 
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
@@ -87,7 +87,7 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         image = image.to(torch_device)
         with torch.no_grad():
-            output = model(image, sample_posterior=True)
+            output = model(image, sample_posterior=True).sample
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
@@ -85,7 +85,7 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
         image = torch.randn(1, model.config.in_channels, model.config.sample_size, model.config.sample_size)
         image = image.to(torch_device)
         with torch.no_grad():
-            output = model(image)
+            output = model(image).sample
 
         output_slice = output[0, -1, -3:, -3:].flatten().cpu()
         # fmt: off
@@ -67,12 +67,12 @@ def test_progress_bar(capsys):
     scheduler = DDPMScheduler(num_train_timesteps=10)
 
     ddpm = DDPMPipeline(model, scheduler).to(torch_device)
-    ddpm(output_type="numpy")["sample"]
+    ddpm(output_type="numpy").images
     captured = capsys.readouterr()
    assert "10/10" in captured.err, "Progress bar has to be displayed"
 
     ddpm.set_progress_bar_config(disable=True)
-    ddpm(output_type="numpy")["sample"]
+    ddpm(output_type="numpy").images
     captured = capsys.readouterr()
     assert captured.err == "", "Progress bar should be disabled"
 
@@ -196,15 +196,20 @@ class PipelineFastTests(unittest.TestCase):
         ddpm.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
+        image = ddpm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = ddpm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array(
             [1.000e00, 5.717e-01, 4.717e-01, 1.000e00, 0.000e00, 1.000e00, 3.000e-04, 0.000e00, 9.000e-04]
         )
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_pndm_cifar10(self):
         unet = self.dummy_uncond_unet
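The pipeline-side counterpart in user code, sketched with DDPM (the checkpoint id is an assumption for illustration; any DDPM checkpoint on the Hub works):

```python
import torch

from diffusers import DDPMPipeline

pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")  # assumed checkpoint id

generator = torch.manual_seed(0)
images = pipe(generator=generator, output_type="numpy").images  # named field

generator = torch.manual_seed(0)
images_from_tuple = pipe(generator=generator, output_type="numpy", return_dict=False)[0]
```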
@@ -213,14 +218,20 @@ class PipelineFastTests(unittest.TestCase):
         pndm = PNDMPipeline(unet=unet, scheduler=scheduler)
         pndm.to(torch_device)
         pndm.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = pndm(generator=generator, num_inference_steps=20, output_type="numpy")["sample"]
+        image = pndm(generator=generator, num_inference_steps=20, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = pndm(generator=generator, num_inference_steps=20, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([1.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_ldm_text2img(self):
         unet = self.dummy_cond_unet
@@ -239,11 +250,23 @@ class PipelineFastTests(unittest.TestCase):
             "sample"
         ]
 
+        generator = torch.manual_seed(0)
+        image_from_tuple = ldm(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="numpy",
+            return_dict=False,
+        )[0]
+
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
         expected_slice = np.array([0.5074, 0.5026, 0.4998, 0.4056, 0.3523, 0.4649, 0.5289, 0.5299, 0.4897])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_ddim(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -274,16 +297,28 @@ class PipelineFastTests(unittest.TestCase):
         sd_pipe.set_progress_bar_config(disable=None)
 
         prompt = "A painting of a squirrel eating a burger"
 
         generator = torch.Generator(device=device).manual_seed(0)
         output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
-
-        image = output["sample"]
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            return_dict=False,
+        )[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 128, 128, 3)
         expected_slice = np.array([0.5112, 0.4692, 0.4715, 0.5206, 0.4894, 0.5114, 0.5096, 0.4932, 0.4755])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -310,13 +345,25 @@ class PipelineFastTests(unittest.TestCase):
         generator = torch.Generator(device=device).manual_seed(0)
         output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
 
-        image = output["sample"]
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            return_dict=False,
+        )[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 128, 128, 3)
         expected_slice = np.array([0.4937, 0.4649, 0.4716, 0.5145, 0.4889, 0.513, 0.513, 0.4905, 0.4738])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_k_lms(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -343,13 +390,25 @@ class PipelineFastTests(unittest.TestCase):
         generator = torch.Generator(device=device).manual_seed(0)
         output = sd_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="np")
 
-        image = output["sample"]
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            return_dict=False,
+        )[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 128, 128, 3)
         expected_slice = np.array([0.5067, 0.4689, 0.4614, 0.5233, 0.4903, 0.5112, 0.524, 0.5069, 0.4785])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_score_sde_ve_pipeline(self):
         unet = self.dummy_uncond_unet
@@ -360,14 +419,19 @@ class PipelineFastTests(unittest.TestCase):
         sde_ve.set_progress_bar_config(disable=None)
 
         torch.manual_seed(0)
-        image = sde_ve(num_inference_steps=2, output_type="numpy")["sample"]
+        image = sde_ve(num_inference_steps=2, output_type="numpy").images
+
+        torch.manual_seed(0)
+        image_from_tuple = sde_ve(num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
 
         expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_ldm_uncond(self):
         unet = self.dummy_uncond_unet
@@ -379,13 +443,18 @@ class PipelineFastTests(unittest.TestCase):
         ldm.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = ldm(generator=generator, num_inference_steps=2, output_type="numpy")["sample"]
+        image = ldm(generator=generator, num_inference_steps=2, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = ldm(generator=generator, num_inference_steps=2, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 64, 64, 3)
         expected_slice = np.array([0.8512, 0.818, 0.6411, 0.6808, 0.4465, 0.5618, 0.46, 0.6231, 0.5172])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_karras_ve_pipeline(self):
         unet = self.dummy_uncond_unet
@@ -396,12 +465,18 @@ class PipelineFastTests(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = pipe(num_inference_steps=2, generator=generator, output_type="numpy")["sample"]
+        image = pipe(num_inference_steps=2, generator=generator, output_type="numpy").images
+
+        generator = torch.manual_seed(0)
+        image_from_tuple = pipe(num_inference_steps=2, generator=generator, output_type="numpy", return_dict=False)[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_img2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -437,13 +512,26 @@ class PipelineFastTests(unittest.TestCase):
             init_image=init_image,
         )
 
-        image = output["sample"]
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            return_dict=False,
+        )[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([0.4492, 0.3865, 0.4222, 0.5854, 0.5139, 0.4379, 0.4193, 0.48, 0.4218])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_img2img_k_lms(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -479,14 +567,27 @@ class PipelineFastTests(unittest.TestCase):
             output_type="np",
             init_image=init_image,
         )
+        image = output.images
 
-        image = output["sample"]
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            return_dict=False,
+        )
+        image_from_tuple = output[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([0.4367, 0.4986, 0.4372, 0.6706, 0.5665, 0.444, 0.5864, 0.6019, 0.5203])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
     def test_stable_diffusion_inpaint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
@@ -525,13 +626,27 @@ class PipelineFastTests(unittest.TestCase):
             mask_image=mask_image,
         )
 
-        image = output["sample"]
+        image = output.images
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        image_from_tuple = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            num_inference_steps=2,
+            output_type="np",
+            init_image=init_image,
+            mask_image=mask_image,
+            return_dict=False,
+        )[0]
 
         image_slice = image[0, -3:, -3:, -1]
+        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
 
         assert image.shape == (1, 32, 32, 3)
         expected_slice = np.array([0.4731, 0.5346, 0.4531, 0.6251, 0.5446, 0.4057, 0.5527, 0.5896, 0.5153])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
 
 class PipelineTesterMixin(unittest.TestCase):
@@ -565,9 +680,9 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
+        image = ddpm(generator=generator, output_type="numpy").images
         generator = generator.manual_seed(0)
-        new_image = new_ddpm(generator=generator, output_type="numpy")["sample"]
+        new_image = new_ddpm(generator=generator, output_type="numpy").images
 
         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
@@ -586,9 +701,9 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
+        image = ddpm(generator=generator, output_type="numpy").images
         generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
+        new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
 
         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
@@ -610,9 +725,9 @@ class PipelineTesterMixin(unittest.TestCase):
 
         generator = torch.manual_seed(0)
 
-        image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy")["sample"]
+        image = ddpm_from_hub_custom_model(generator=generator, output_type="numpy").images
         generator = generator.manual_seed(0)
-        new_image = ddpm_from_hub(generator=generator, output_type="numpy")["sample"]
+        new_image = ddpm_from_hub(generator=generator, output_type="numpy").images
 
         assert np.abs(image - new_image).sum() < 1e-5, "Models don't give the same forward pass"
 
@@ -625,17 +740,17 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        images = pipe(generator=generator, output_type="numpy")["sample"]
+        images = pipe(generator=generator, output_type="numpy").images
         assert images.shape == (1, 32, 32, 3)
         assert isinstance(images, np.ndarray)
 
-        images = pipe(generator=generator, output_type="pil")["sample"]
+        images = pipe(generator=generator, output_type="pil").images
         assert isinstance(images, list)
         assert len(images) == 1
         assert isinstance(images[0], PIL.Image.Image)
 
         # use PIL by default
-        images = pipe(generator=generator)["sample"]
+        images = pipe(generator=generator).images
         assert isinstance(images, list)
         assert isinstance(images[0], PIL.Image.Image)
@@ -652,7 +767,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
+        image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -672,7 +787,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddpm.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = ddpm(generator=generator, output_type="numpy")["sample"]
+        image = ddpm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -692,7 +807,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddim.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = ddim(generator=generator, eta=0.0, output_type="numpy")["sample"]
+        image = ddim(generator=generator, eta=0.0, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -711,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pndm.to(torch_device)
         pndm.set_progress_bar_config(disable=None)
         generator = torch.manual_seed(0)
-        image = pndm(generator=generator, output_type="numpy")["sample"]
+        image = pndm(generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -745,7 +860,7 @@ class PipelineTesterMixin(unittest.TestCase):
 
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy")["sample"]
+        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -768,7 +883,7 @@ class PipelineTesterMixin(unittest.TestCase):
             [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
         )
 
-        image = output["sample"]
+        image = output.images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -797,7 +912,7 @@ class PipelineTesterMixin(unittest.TestCase):
 
         with torch.autocast("cuda"):
             output = sd_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
-            image = output["sample"]
+            image = output.images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -817,7 +932,7 @@ class PipelineTesterMixin(unittest.TestCase):
         sde_ve.set_progress_bar_config(disable=None)
 
         torch.manual_seed(0)
-        image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
+        image = sde_ve(num_inference_steps=300, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
 
@@ -833,7 +948,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ldm.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
+        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
@@ -857,10 +972,10 @@ class PipelineTesterMixin(unittest.TestCase):
         ddim.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        ddpm_image = ddpm(generator=generator, output_type="numpy")["sample"]
+        ddpm_image = ddpm(generator=generator, output_type="numpy").images
 
         generator = torch.manual_seed(0)
-        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")["sample"]
+        ddim_image = ddim(generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy").images
 
         # the values aren't exactly equal, but the images look the same visually
         assert np.abs(ddpm_image - ddim_image).max() < 1e-1
 
@@ -882,7 +997,7 @@ class PipelineTesterMixin(unittest.TestCase):
         ddim.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy")["sample"]
+        ddpm_images = ddpm(batch_size=4, generator=generator, output_type="numpy").images
 
         generator = torch.manual_seed(0)
         ddim_images = ddim(batch_size=4, generator=generator, num_inference_steps=1000, eta=1.0, output_type="numpy")[
 
@@ -903,7 +1018,7 @@ class PipelineTesterMixin(unittest.TestCase):
         pipe.set_progress_bar_config(disable=None)
 
         generator = torch.manual_seed(0)
-        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy")["sample"]
+        image = pipe(num_inference_steps=20, generator=generator, output_type="numpy").images
 
         image_slice = image[0, -3:, -3:, -1]
         assert image.shape == (1, 256, 256, 3)
 
@@ -974,9 +1089,8 @@ class PipelineTesterMixin(unittest.TestCase):
         prompt = "A fantasy landscape, trending on artstation"
 
         generator = torch.Generator(device=torch_device).manual_seed(0)
-        image = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)[
-            "sample"
-        ][0]
+        output = pipe(prompt=prompt, init_image=init_image, strength=0.75, guidance_scale=7.5, generator=generator)
+        image = output.images[0]
 
         expected_array = np.array(output_image)
         sampled_array = np.array(image)
 
@@ -1008,7 +1122,7 @@ class PipelineTesterMixin(unittest.TestCase):
             strength=0.75,
             guidance_scale=7.5,
             generator=generator,
-        )["sample"][0]
+        ).images[0]
 
         expected_array = np.array(output_image)
         sampled_array = np.array(image)
@@ -14,6 +14,7 @@
 # limitations under the License.
 import tempfile
 import unittest
+from typing import Dict, List, Tuple
 
 import numpy as np
 import torch
 
@@ -85,8 +86,8 @@ class SchedulerCommonTest(unittest.TestCase):
         elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
             kwargs["num_inference_steps"] = num_inference_steps
 
-        output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
@@ -114,9 +115,9 @@ class SchedulerCommonTest(unittest.TestCase):
             kwargs["num_inference_steps"] = num_inference_steps
 
         torch.manual_seed(0)
-        output = scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step(residual, time_step, sample, **kwargs).prev_sample
         torch.manual_seed(0)
-        new_output = new_scheduler.step(residual, time_step, sample, **kwargs)["prev_sample"]
+        new_output = new_scheduler.step(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
@@ -143,9 +144,9 @@ class SchedulerCommonTest(unittest.TestCase):
             kwargs["num_inference_steps"] = num_inference_steps
 
         torch.manual_seed(0)
-        output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
+        output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
         torch.manual_seed(0)
-        new_output = new_scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
+        new_output = new_scheduler.step(residual, 1, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
@@ -166,8 +167,8 @@ class SchedulerCommonTest(unittest.TestCase):
         elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
             kwargs["num_inference_steps"] = num_inference_steps
 
-        output_0 = scheduler.step(residual, 0, sample, **kwargs)["prev_sample"]
-        output_1 = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
+        output_0 = scheduler.step(residual, 0, sample, **kwargs).prev_sample
+        output_1 = scheduler.step(residual, 1, sample, **kwargs).prev_sample
 
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
@@ -195,11 +196,64 @@ class SchedulerCommonTest(unittest.TestCase):
         elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
             kwargs["num_inference_steps"] = num_inference_steps
 
-        output = scheduler.step(residual, 1, sample, **kwargs)["prev_sample"]
-        output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
+        output = scheduler.step(residual, 1, sample, **kwargs).prev_sample
+        output_pt = scheduler_pt.step(residual_pt, 1, sample_pt, **kwargs).prev_sample
 
         assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
 
+    def test_scheduler_outputs_equivalence(self):
+        def set_nan_tensor_to_zero(t):
+            t[t != t] = 0
+            return t
+
+        def recursive_check(tuple_object, dict_object):
+            if isinstance(tuple_object, (List, Tuple)):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif isinstance(tuple_object, Dict):
+                for tuple_iterable_value, dict_iterable_value in zip(tuple_object.values(), dict_object.values()):
+                    recursive_check(tuple_iterable_value, dict_iterable_value)
+            elif tuple_object is None:
+                return
+            else:
+                self.assertTrue(
+                    torch.allclose(
+                        set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
+                    ),
+                    msg=(
+                        "Tuple and dict output are not equal. Difference:"
+                        f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
+                        f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has"
+                        f" `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}."
+                    ),
+                )
+
+        kwargs = dict(self.forward_default_kwargs)
+        num_inference_steps = kwargs.pop("num_inference_steps", None)
+
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            sample = self.dummy_sample
+            residual = 0.1 * sample
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            outputs_dict = scheduler.step(residual, 0, sample, **kwargs)
+
+            if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
+                scheduler.set_timesteps(num_inference_steps)
+            elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
+                kwargs["num_inference_steps"] = num_inference_steps
+
+            outputs_tuple = scheduler.step(residual, 0, sample, return_dict=False, **kwargs)
+
+            recursive_check(outputs_tuple, outputs_dict)
+
 
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
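Note that the old dict-style access keeps working during the deprecation window, because `SchedulerOutput` is a `BaseOutput` (an `OrderedDict`). A small sketch of the invariant:

```python
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
sample = torch.randn(1, 3, 8, 8)
residual = 0.1 * sample

out = scheduler.step(residual, 0, sample)  # timestep 0: DDPM draws no extra noise here
assert torch.equal(out.prev_sample, out["prev_sample"])  # attribute == key access
assert torch.equal(out.prev_sample, out[0])              # == integer indexing
```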
@@ -270,7 +324,7 @@ class DDPMSchedulerTest(SchedulerCommonTest):
             residual = model(sample, t)
 
             # 2. predict previous mean of sample x_t-1
-            pred_prev_sample = scheduler.step(residual, t, sample)["prev_sample"]
+            pred_prev_sample = scheduler.step(residual, t, sample).prev_sample
 
             # if t > 0:
             #     noise = self.dummy_sample_deter
 
@@ -356,7 +410,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
         for t in scheduler.timesteps:
             residual = model(sample, t)
 
-            sample = scheduler.step(residual, t, sample, eta)["prev_sample"]
+            sample = scheduler.step(residual, t, sample, eta).prev_sample
 
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
@@ -401,13 +455,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         # copy over dummy past residuals
         new_scheduler.ets = dummy_past_residuals[:]
 
-        output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-        output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
@@ -438,13 +492,13 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         # copy over dummy past residual (must be after setting timesteps)
         new_scheduler.ets = dummy_past_residuals[:]
 
-        output = scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_prk(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-        output = scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_plms(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
@@ -476,12 +530,12 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         scheduler.ets = dummy_past_residuals[:]
         scheduler_pt.ets = dummy_past_residuals_pt[:]
 
-        output = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
-        output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
+        output = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
+        output_pt = scheduler_pt.step_prk(residual_pt, 1, sample_pt, **kwargs).prev_sample
         assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
 
-        output = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
-        output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs)["prev_sample"]
+        output = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
+        output_pt = scheduler_pt.step_plms(residual_pt, 1, sample_pt, **kwargs).prev_sample
 
         assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"
 
@@ -535,14 +589,14 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         dummy_past_residuals = [residual + 0.2, residual + 0.15, residual + 0.1, residual + 0.05]
         scheduler.ets = dummy_past_residuals[:]
 
-        output_0 = scheduler.step_prk(residual, 0, sample, **kwargs)["prev_sample"]
-        output_1 = scheduler.step_prk(residual, 1, sample, **kwargs)["prev_sample"]
+        output_0 = scheduler.step_prk(residual, 0, sample, **kwargs).prev_sample
+        output_1 = scheduler.step_prk(residual, 1, sample, **kwargs).prev_sample
 
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
-        output_0 = scheduler.step_plms(residual, 0, sample, **kwargs)["prev_sample"]
-        output_1 = scheduler.step_plms(residual, 1, sample, **kwargs)["prev_sample"]
+        output_0 = scheduler.step_plms(residual, 0, sample, **kwargs).prev_sample
+        output_1 = scheduler.step_plms(residual, 1, sample, **kwargs).prev_sample
 
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
 
@@ -573,7 +627,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
         scheduler_config = self.get_scheduler_config()
         scheduler = scheduler_class(**scheduler_config)
 
-        scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample)["prev_sample"]
+        scheduler.step_plms(self.dummy_sample, 1, self.dummy_sample).prev_sample
 
     def test_full_loop_no_noise(self):
         scheduler_class = self.scheduler_classes[0]
 
@@ -587,11 +641,11 @@ class PNDMSchedulerTest(SchedulerCommonTest):
 
         for i, t in enumerate(scheduler.prk_timesteps):
             residual = model(sample, t)
-            sample = scheduler.step_prk(residual, i, sample)["prev_sample"]
+            sample = scheduler.step_prk(residual, i, sample).prev_sample
 
         for i, t in enumerate(scheduler.plms_timesteps):
             residual = model(sample, t)
-            sample = scheduler.step_plms(residual, i, sample)["prev_sample"]
+            sample = scheduler.step_plms(residual, i, sample).prev_sample
 
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
@@ -664,13 +718,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
             scheduler.save_config(tmpdirname)
             new_scheduler = scheduler_class.from_config(tmpdirname)
 
-        output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-        output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"
 
@@ -689,13 +743,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
             scheduler.save_config(tmpdirname)
             new_scheduler = scheduler_class.from_config(tmpdirname)
 
-        output = scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_pred(residual, time_step, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
-        output = scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
-        new_output = new_scheduler.step_correct(residual, sample, **kwargs)["prev_sample"]
+        output = scheduler.step_correct(residual, sample, **kwargs).prev_sample
+        new_output = new_scheduler.step_correct(residual, sample, **kwargs).prev_sample
 
         assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler corrections are not identical"
 
@@ -732,13 +786,13 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
             for _ in range(scheduler.correct_steps):
                 with torch.no_grad():
                     model_output = model(sample, sigma_t)
-                sample = scheduler.step_correct(model_output, sample, **kwargs)["prev_sample"]
+                sample = scheduler.step_correct(model_output, sample, **kwargs).prev_sample
 
             with torch.no_grad():
                 model_output = model(sample, sigma_t)
 
             output = scheduler.step_pred(model_output, t, sample, **kwargs)
-            sample, _ = output["prev_sample"], output["prev_sample_mean"]
+            sample, _ = output.prev_sample, output.prev_sample_mean
 
         result_sum = torch.sum(torch.abs(sample))
         result_mean = torch.mean(torch.abs(sample))
 
@@ -763,8 +817,8 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
         elif num_inference_steps is not None and not hasattr(scheduler, "set_timesteps"):
             kwargs["num_inference_steps"] = num_inference_steps
 
-        output_0 = scheduler.step_pred(residual, 0, sample, **kwargs)["prev_sample"]
-        output_1 = scheduler.step_pred(residual, 1, sample, **kwargs)["prev_sample"]
+        output_0 = scheduler.step_pred(residual, 0, sample, **kwargs).prev_sample
+        output_1 = scheduler.step_pred(residual, 1, sample, **kwargs).prev_sample
 
         self.assertEqual(output_0.shape, sample.shape)
         self.assertEqual(output_0.shape, output_1.shape)
@@ -66,7 +66,7 @@ class TrainingTests(unittest.TestCase):
         for i in range(4):
             optimizer.zero_grad()
             ddpm_noisy_images = ddpm_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
-            ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i])["sample"]
+            ddpm_noise_pred = model(ddpm_noisy_images, timesteps[i]).sample
             loss = torch.nn.functional.mse_loss(ddpm_noise_pred, noise[i])
             loss.backward()
             optimizer.step()
 
@@ -78,7 +78,7 @@ class TrainingTests(unittest.TestCase):
         for i in range(4):
             optimizer.zero_grad()
             ddim_noisy_images = ddim_scheduler.add_noise(clean_images[i], noise[i], timesteps[i])
-            ddim_noise_pred = model(ddim_noisy_images, timesteps[i])["sample"]
+            ddim_noise_pred = model(ddim_noisy_images, timesteps[i]).sample
             loss = torch.nn.functional.mse_loss(ddim_noise_pred, noise[i])
             loss.backward()
             optimizer.step()
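For completeness, the training-side shape of the change as a hedged, self-contained sketch (the `Conv2d` is a stand-in module; with a real diffusers UNet the prediction would be read as `model(noisy_images, timesteps).sample` instead of the old `["sample"]` lookup):

```python
import torch

model = torch.nn.Conv2d(3, 3, 1)  # stand-in; a real run would use a diffusers UNet
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

clean_images = torch.randn(4, 1, 3, 32, 32)
noise = torch.randn_like(clean_images)

for i in range(4):
    optimizer.zero_grad()
    noisy_images = clean_images[i] + noise[i]  # a real run would use scheduler.add_noise(...)
    noise_pred = model(noisy_images)           # diffusers model: model(noisy_images, timesteps[i]).sample
    loss = torch.nn.functional.mse_loss(noise_pred, noise[i])
    loss.backward()
    optimizer.step()
```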