Refactor progress bar (#242)

* Refactor progress bar of pipeline __call__

* Make any tqdm configs available

* remove init

* add some tests

* remove file

* finish

* make style

* improve progress bar test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
hysts 2022-08-30 19:30:06 +09:00 committed by GitHub
parent efa773afd2
commit 5e84353eba
10 changed files with 45 additions and 21 deletions
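
In effect, this PR replaces the hard-coded `tqdm(...)` calls inside each pipeline's `__call__` with a single, configurable progress bar on `DiffusionPipeline`. A minimal usage sketch, assuming a DDPM checkpoint such as `google/ddpm-cifar10-32` (the checkpoint id is only illustrative; per the commit message, any tqdm keyword argument should be forwarded):

from diffusers import DDPMPipeline

# Illustrative checkpoint; every DiffusionPipeline subclass gains the same API.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

# Forward arbitrary tqdm kwargs to the bar shown during __call__.
pipe.set_progress_bar_config(desc="Sampling", leave=False)

# Each call replaces the stored config; this silences the bar entirely,
# as exercised by the new test at the end of this diff.
pipe.set_progress_bar_config(disable=True)

images = pipe(output_type="numpy")["sample"]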

View File

@@ -23,6 +23,7 @@ import torch
 from huggingface_hub import snapshot_download
 from PIL import Image
+from tqdm.auto import tqdm

 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -266,3 +267,16 @@ class DiffusionPipeline(ConfigMixin):
         pil_images = [Image.fromarray(image) for image in images]

         return pil_images
+
+    def progress_bar(self, iterable):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        return tqdm(iterable, **self._progress_bar_config)
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
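
For pipeline implementations, the pattern shown in the files below is to wrap the denoising loop with `self.progress_bar(...)` instead of calling `tqdm` directly, so one configuration applies to every pipeline. A short sketch of a custom pipeline under that assumption; `MyPipeline`, the fixed 1x3x32x32 sample shape, and the default step count are illustrative, not part of this PR:

import torch

from diffusers import DiffusionPipeline


class MyPipeline(DiffusionPipeline):
    # Hypothetical pipeline; only the progress-bar usage mirrors this PR.
    def __init__(self, unet, scheduler):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler)

    @torch.no_grad()
    def __call__(self, num_inference_steps=50):
        # Illustrative fixed shape; real pipelines derive it from the model config.
        image = torch.randn(1, 3, 32, 32, device=self.device)
        self.scheduler.set_timesteps(num_inference_steps)

        # self.progress_bar honors whatever was passed to set_progress_bar_config(),
        # replacing the previous per-pipeline tqdm(...) wrappers.
        for t in self.progress_bar(self.scheduler.timesteps):
            model_output = self.unet(image, t)["sample"]
            image = self.scheduler.step(model_output, t, image)["prev_sample"]

        return {"sample": image}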

View File

@@ -18,8 +18,6 @@ import warnings
 import torch
-from tqdm.auto import tqdm
 from ...pipeline_utils import DiffusionPipeline
@@ -56,7 +54,7 @@ class DDIMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t)["sample"]

View File

@@ -18,8 +18,6 @@ import warnings
 import torch
-from tqdm.auto import tqdm
 from ...pipeline_utils import DiffusionPipeline
@@ -53,7 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(1000)
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t)["sample"]

View File

@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -83,7 +82,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         if accepts_eta:
             extra_kwargs["eta"] = eta
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             if guidance_scale == 1.0:
                 # guidance_scale of 1 means no guidance
                 latents_input = latents

View File

@@ -3,8 +3,6 @@ import warnings
 import torch
-from tqdm.auto import tqdm
 from ...pipeline_utils import DiffusionPipeline
@@ -45,7 +43,7 @@ class LDMPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_kwargs["eta"] = eta
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # predict the noise residual
             noise_prediction = self.unet(latents, t)["sample"]
             # compute the previous noisy sample x_t -> x_t-1

View File

@@ -18,8 +18,6 @@ import warnings
 import torch
-from tqdm.auto import tqdm
 from ...pipeline_utils import DiffusionPipeline
@@ -54,7 +52,7 @@ class PNDMPipeline(DiffusionPipeline):
         image = image.to(self.device)
         self.scheduler.set_timesteps(num_inference_steps)
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             model_output = self.unet(image, t)["sample"]
             image = self.scheduler.step(model_output, t, image)["prev_sample"]

View File

@@ -4,7 +4,6 @@ import warnings
 import torch
 from diffusers import DiffusionPipeline
-from tqdm.auto import tqdm
 class ScoreSdeVePipeline(DiffusionPipeline):
@@ -37,7 +36,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
             # correction step

View File

@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 import torch
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...models import AutoencoderKL, UNet2DConditionModel
@@ -133,7 +132,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
             if isinstance(self.scheduler, LMSDiscreteScheduler):

View File

@@ -3,8 +3,6 @@ import warnings
 import torch
-from tqdm.auto import tqdm
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import KarrasVeScheduler
@@ -53,7 +51,7 @@ class KarrasVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # here sigma_t == t_i from the paper
             sigma = self.scheduler.schedule[t]
             sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0

View File

@@ -44,6 +44,29 @@ from diffusers.testing_utils import slow, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False

+def test_progress_bar(capsys):
+    model = UNet2DModel(
+        block_out_channels=(32, 64),
+        layers_per_block=2,
+        sample_size=32,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+    )
+    scheduler = DDPMScheduler(num_train_timesteps=10)
+    ddpm = DDPMPipeline(model, scheduler).to(torch_device)
+
+    ddpm(output_type="numpy")["sample"]
+    captured = capsys.readouterr()
+    assert "10/10" in captured.err, "Progress bar has to be displayed"
+
+    ddpm.set_progress_bar_config(disable=True)
+    ddpm(output_type="numpy")["sample"]
+    captured = capsys.readouterr()
+    assert captured.err == "", "Progress bar should be disabled"
+
+
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
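
One note on the test: tqdm writes its progress bar to stderr by default, which is why `test_progress_bar` asserts on `capsys.readouterr().err` rather than `.out`, and why the 10-step `DDPMScheduler` makes "10/10" appear there. A tiny sketch of that default behavior (nothing here is diffusers-specific):

from tqdm.auto import tqdm

# The bar below goes to sys.stderr, not sys.stdout; capturing stdout alone
# would therefore miss the "10/10" completion marker.
for _ in tqdm(range(10)):
    pass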