Refactor progress bar (#242)

* Refactor progress bar of pipeline __call__

* Make any tqdm configs available

* remove init

* add some tests

* remove file

* finish

* make style

* improve progress bar test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
hysts 2022-08-30 19:30:06 +09:00 committed by GitHub
parent efa773afd2
commit 5e84353eba
10 changed files with 45 additions and 21 deletions

src/diffusers/pipeline_utils.py

@@ -23,6 +23,7 @@ import torch
 from huggingface_hub import snapshot_download
 from PIL import Image
+from tqdm.auto import tqdm
 
 from .configuration_utils import ConfigMixin
 from .utils import DIFFUSERS_CACHE, logging
@@ -266,3 +267,16 @@ class DiffusionPipeline(ConfigMixin):
         pil_images = [Image.fromarray(image) for image in images]
 
         return pil_images
+
+    def progress_bar(self, iterable):
+        if not hasattr(self, "_progress_bar_config"):
+            self._progress_bar_config = {}
+        elif not isinstance(self._progress_bar_config, dict):
+            raise ValueError(
+                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
+            )
+
+        return tqdm(iterable, **self._progress_bar_config)
+
+    def set_progress_bar_config(self, **kwargs):
+        self._progress_bar_config = kwargs
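
Taken together, these two methods let callers forward any tqdm keyword argument to the sampling loop. A minimal usage sketch (the checkpoint name is illustrative and not part of this commit; every pipeline that inherits from DiffusionPipeline exposes the same two methods):

from diffusers import DDPMPipeline

# Illustrative checkpoint name; substitute any DDPM-compatible checkpoint.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32")

# Forward arbitrary tqdm options to the denoising loop.
pipe.set_progress_bar_config(desc="Denoising", leave=False)
image = pipe(output_type="numpy")["sample"]

# Or silence the bar entirely.
pipe.set_progress_bar_config(disable=True)
image = pipe(output_type="numpy")["sample"]

Note that set_progress_bar_config overwrites _progress_bar_config wholesale rather than merging, so the most recent call wins.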

src/diffusers/pipelines/ddim/pipeline_ddim.py

@@ -18,8 +18,6 @@ import warnings
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline

@@ -56,7 +54,7 @@ class DDIMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t)["sample"]

src/diffusers/pipelines/ddpm/pipeline_ddpm.py

@@ -18,8 +18,6 @@ import warnings
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline

@@ -53,7 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(1000)
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # 1. predict noise model_output
             model_output = self.unet(image, t)["sample"]

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

@@ -6,7 +6,6 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 
-from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput

@@ -83,7 +82,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
         if accepts_eta:
             extra_kwargs["eta"] = eta
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             if guidance_scale == 1.0:
                 # guidance_scale of 1 means no guidance
                 latents_input = latents

src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py

@@ -3,8 +3,6 @@ import warnings
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline

@@ -45,7 +43,7 @@ class LDMPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_kwargs["eta"] = eta
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # predict the noise residual
             noise_prediction = self.unet(latents, t)["sample"]
             # compute the previous noisy sample x_t -> x_t-1

src/diffusers/pipelines/pndm/pipeline_pndm.py

@@ -18,8 +18,6 @@ import warnings
 import torch
 
-from tqdm.auto import tqdm
-
 from ...pipeline_utils import DiffusionPipeline

@@ -54,7 +52,7 @@ class PNDMPipeline(DiffusionPipeline):
         image = image.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             model_output = self.unet(image, t)["sample"]
 
             image = self.scheduler.step(model_output, t, image)["prev_sample"]

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

@@ -4,7 +4,6 @@ import warnings
 import torch
 
 from diffusers import DiffusionPipeline
-from tqdm.auto import tqdm
 
 
 class ScoreSdeVePipeline(DiffusionPipeline):

@@ -37,7 +36,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
         self.scheduler.set_sigmas(num_inference_steps)
 
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
 
             # correction step

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -4,7 +4,6 @@ from typing import List, Optional, Union
 import torch
 
-from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from ...models import AutoencoderKL, UNet2DConditionModel

@@ -133,7 +132,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
-        for i, t in tqdm(enumerate(self.scheduler.timesteps)):
+        for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
             if isinstance(self.scheduler, LMSDiscreteScheduler):
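
Note that the two pipelines that enumerate the timesteps (ScoreSdeVePipeline and StableDiffusionPipeline) also swap the nesting order. tqdm(enumerate(...)) cannot infer a total, since an enumerate object has no length, so the old bar only showed a bare iteration count; enumerating the result of self.progress_bar(...) instead hands tqdm the sized timestep sequence, restoring the n/total display.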

src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py

@@ -3,8 +3,6 @@ import warnings
 import torch
 
-from tqdm.auto import tqdm
-
 from ...models import UNet2DModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import KarrasVeScheduler

@@ -53,7 +51,7 @@ class KarrasVePipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm(self.scheduler.timesteps):
+        for t in self.progress_bar(self.scheduler.timesteps):
             # here sigma_t == t_i from the paper
             sigma = self.scheduler.schedule[t]
             sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0

tests/test_pipelines.py

@@ -44,6 +44,29 @@ from diffusers.testing_utils import slow, torch_device
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+def test_progress_bar(capsys):
+    model = UNet2DModel(
+        block_out_channels=(32, 64),
+        layers_per_block=2,
+        sample_size=32,
+        in_channels=3,
+        out_channels=3,
+        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
+        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
+    )
+    scheduler = DDPMScheduler(num_train_timesteps=10)
+
+    ddpm = DDPMPipeline(model, scheduler).to(torch_device)
+    ddpm(output_type="numpy")["sample"]
+    captured = capsys.readouterr()
+    assert "10/10" in captured.err, "Progress bar has to be displayed"
+
+    ddpm.set_progress_bar_config(disable=True)
+    ddpm(output_type="numpy")["sample"]
+    captured = capsys.readouterr()
+    assert captured.err == "", "Progress bar should be disabled"
+
+
 class PipelineTesterMixin(unittest.TestCase):
     def test_from_pretrained_save_pretrained(self):
         # 1. Load models
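
One detail worth noting about the test: tqdm writes its progress display to stderr by default, which is why both assertions inspect captured.err rather than captured.out, and why the ten-step DDPMScheduler surfaces as the string "10/10" in the captured output.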