diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index f27789f0..214133bc 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -23,6 +23,7 @@ import torch from huggingface_hub import snapshot_download from PIL import Image +from tqdm.auto import tqdm from .configuration_utils import ConfigMixin from .utils import DIFFUSERS_CACHE, logging @@ -266,3 +267,16 @@ class DiffusionPipeline(ConfigMixin): pil_images = [Image.fromarray(image) for image in images] return pil_images + + def progress_bar(self, iterable): + if not hasattr(self, "_progress_bar_config"): + self._progress_bar_config = {} + elif not isinstance(self._progress_bar_config, dict): + raise ValueError( + f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." + ) + + return tqdm(iterable, **self._progress_bar_config) + + def set_progress_bar_config(self, **kwargs): + self._progress_bar_config = kwargs diff --git a/src/diffusers/pipelines/ddim/pipeline_ddim.py b/src/diffusers/pipelines/ddim/pipeline_ddim.py index 700e2b9c..03a9c52b 100644 --- a/src/diffusers/pipelines/ddim/pipeline_ddim.py +++ b/src/diffusers/pipelines/ddim/pipeline_ddim.py @@ -18,8 +18,6 @@ import warnings import torch -from tqdm.auto import tqdm - from ...pipeline_utils import DiffusionPipeline @@ -56,7 +54,7 @@ class DDIMPipeline(DiffusionPipeline): # set step values self.scheduler.set_timesteps(num_inference_steps) - for t in tqdm(self.scheduler.timesteps): + for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output model_output = self.unet(image, t)["sample"] diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index 099add5d..27c156de 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -18,8 +18,6 @@ import warnings import torch -from tqdm.auto import tqdm - from ...pipeline_utils import DiffusionPipeline @@ -53,7 +51,7 @@ class DDPMPipeline(DiffusionPipeline): # set step values self.scheduler.set_timesteps(1000) - for t in tqdm(self.scheduler.timesteps): + for t in self.progress_bar(self.scheduler.timesteps): # 1. predict noise model_output model_output = self.unet(image, t)["sample"] diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 17a15aca..a348d9c0 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from tqdm.auto import tqdm from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from transformers.modeling_outputs import BaseModelOutput @@ -83,7 +82,7 @@ class LDMTextToImagePipeline(DiffusionPipeline): if accepts_eta: extra_kwargs["eta"] = eta - for t in tqdm(self.scheduler.timesteps): + for t in self.progress_bar(self.scheduler.timesteps): if guidance_scale == 1.0: # guidance_scale of 1 means no guidance latents_input = latents diff --git a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py index bdff4fc9..ed9bd09c 100644 --- a/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +++ b/src/diffusers/pipelines/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py @@ -3,8 +3,6 @@ import warnings import torch -from tqdm.auto import tqdm - from ...pipeline_utils import DiffusionPipeline @@ -45,7 +43,7 @@ class LDMPipeline(DiffusionPipeline): if accepts_eta: extra_kwargs["eta"] = eta - for t in tqdm(self.scheduler.timesteps): + for t in self.progress_bar(self.scheduler.timesteps): # predict the noise residual noise_prediction = self.unet(latents, t)["sample"] # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/pndm/pipeline_pndm.py b/src/diffusers/pipelines/pndm/pipeline_pndm.py index bc0f7564..32ddbd8c 100644 --- a/src/diffusers/pipelines/pndm/pipeline_pndm.py +++ b/src/diffusers/pipelines/pndm/pipeline_pndm.py @@ -18,8 +18,6 @@ import warnings import torch -from tqdm.auto import tqdm - from ...pipeline_utils import DiffusionPipeline @@ -54,7 +52,7 @@ class PNDMPipeline(DiffusionPipeline): image = image.to(self.device) self.scheduler.set_timesteps(num_inference_steps) - for t in tqdm(self.scheduler.timesteps): + for t in self.progress_bar(self.scheduler.timesteps): model_output = self.unet(image, t)["sample"] image = self.scheduler.step(model_output, t, image)["prev_sample"] diff --git a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py index 884f1894..7d72ddf7 100644 --- a/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +++ b/src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py @@ -4,7 +4,6 @@ import warnings import torch from diffusers import DiffusionPipeline -from tqdm.auto import tqdm class ScoreSdeVePipeline(DiffusionPipeline): @@ -37,7 +36,7 @@ class ScoreSdeVePipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) self.scheduler.set_sigmas(num_inference_steps) - for i, t in tqdm(enumerate(self.scheduler.timesteps)): + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device) # correction step diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index fca87151..d4290da6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -4,7 +4,6 @@ from typing import List, Optional, Union import torch -from tqdm.auto import tqdm from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel @@ -133,7 +132,7 @@ class StableDiffusionPipeline(DiffusionPipeline): if accepts_eta: extra_step_kwargs["eta"] = eta - for i, t in tqdm(enumerate(self.scheduler.timesteps)): + for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents if isinstance(self.scheduler, LMSDiscreteScheduler): diff --git a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py index ebf95e66..97027299 100644 --- a/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py +++ b/src/diffusers/pipelines/stochatic_karras_ve/pipeline_stochastic_karras_ve.py @@ -3,8 +3,6 @@ import warnings import torch -from tqdm.auto import tqdm - from ...models import UNet2DModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import KarrasVeScheduler @@ -53,7 +51,7 @@ class KarrasVePipeline(DiffusionPipeline): self.scheduler.set_timesteps(num_inference_steps) - for t in tqdm(self.scheduler.timesteps): + for t in self.progress_bar(self.scheduler.timesteps): # here sigma_t == t_i from the paper sigma = self.scheduler.schedule[t] sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0 diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 9b6e0896..96d5995f 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -44,6 +44,29 @@ from diffusers.testing_utils import slow, torch_device torch.backends.cuda.matmul.allow_tf32 = False +def test_progress_bar(capsys): + model = UNet2DModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=3, + out_channels=3, + down_block_types=("DownBlock2D", "AttnDownBlock2D"), + up_block_types=("AttnUpBlock2D", "UpBlock2D"), + ) + scheduler = DDPMScheduler(num_train_timesteps=10) + + ddpm = DDPMPipeline(model, scheduler).to(torch_device) + ddpm(output_type="numpy")["sample"] + captured = capsys.readouterr() + assert "10/10" in captured.err, "Progress bar has to be displayed" + + ddpm.set_progress_bar_config(disable=True) + ddpm(output_type="numpy")["sample"] + captured = capsys.readouterr() + assert captured.err == "", "Progress bar should be disabled" + + class PipelineTesterMixin(unittest.TestCase): def test_from_pretrained_save_pretrained(self): # 1. Load models