Refactor progress bar (#242)
* Refactor progress bar of pipeline `__call__`
* Make any tqdm configs available
* Remove init
* Add some tests
* Remove file
* Finish
* Make style
* Improve progress bar test

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
efa773afd2
commit
5e84353eba
|
@ -23,6 +23,7 @@ import torch
|
|||
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL import Image
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .configuration_utils import ConfigMixin
|
||||
from .utils import DIFFUSERS_CACHE, logging
|
||||
|
@ -266,3 +267,16 @@ class DiffusionPipeline(ConfigMixin):
|
|||
pil_images = [Image.fromarray(image) for image in images]
|
||||
|
||||
return pil_images
|
||||
|
||||
def progress_bar(self, iterable):
    """Wrap `iterable` in a tqdm progress bar.

    The bar is configured with the keyword arguments previously stored by
    `set_progress_bar_config`; if the user never called it, an empty config
    is created lazily so tqdm runs with its defaults.

    Raises:
        ValueError: if `self._progress_bar_config` was set to a non-dict.
    """
    try:
        config = self._progress_bar_config
    except AttributeError:
        # First use without an explicit config: fall back to tqdm defaults.
        config = self._progress_bar_config = {}
    else:
        if not isinstance(config, dict):
            raise ValueError(
                f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
            )

    return tqdm(iterable, **config)
|
||||
|
||||
def set_progress_bar_config(self, **kwargs):
    """Store tqdm keyword arguments (e.g. `disable=True`, `desc=...`) to be
    applied by `progress_bar` on subsequent pipeline calls."""
    self._progress_bar_config = dict(kwargs)
|
||||
|
|
|
@ -18,8 +18,6 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
|
@ -56,7 +54,7 @@ class DDIMPipeline(DiffusionPipeline):
|
|||
# set step values
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
model_output = self.unet(image, t)["sample"]
|
||||
|
||||
|
|
|
@ -18,8 +18,6 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
|
@ -53,7 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
|
|||
# set step values
|
||||
self.scheduler.set_timesteps(1000)
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# 1. predict noise model_output
|
||||
model_output = self.unet(image, t)["sample"]
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
|
@ -83,7 +82,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
|||
if accepts_eta:
|
||||
extra_kwargs["eta"] = eta
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
if guidance_scale == 1.0:
|
||||
# guidance_scale of 1 means no guidance
|
||||
latents_input = latents
|
||||
|
|
|
@ -3,8 +3,6 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
|
@ -45,7 +43,7 @@ class LDMPipeline(DiffusionPipeline):
|
|||
if accepts_eta:
|
||||
extra_kwargs["eta"] = eta
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# predict the noise residual
|
||||
noise_prediction = self.unet(latents, t)["sample"]
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
|
|
@ -18,8 +18,6 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
|
@ -54,7 +52,7 @@ class PNDMPipeline(DiffusionPipeline):
|
|||
image = image.to(self.device)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
model_output = self.unet(image, t)["sample"]
|
||||
|
||||
image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
||||
|
|
|
@ -4,7 +4,6 @@ import warnings
|
|||
import torch
|
||||
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
|
@ -37,7 +36,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
|||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
self.scheduler.set_sigmas(num_inference_steps)
|
||||
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
|
||||
|
||||
# correction step
|
||||
|
|
|
@ -4,7 +4,6 @@ from typing import List, Optional, Union
|
|||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
|
@ -133,7 +132,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
|||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
||||
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||
|
|
|
@ -3,8 +3,6 @@ import warnings
|
|||
|
||||
import torch
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...models import UNet2DModel
|
||||
from ...pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import KarrasVeScheduler
|
||||
|
@ -53,7 +51,7 @@ class KarrasVePipeline(DiffusionPipeline):
|
|||
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
for t in tqdm(self.scheduler.timesteps):
|
||||
for t in self.progress_bar(self.scheduler.timesteps):
|
||||
# here sigma_t == t_i from the paper
|
||||
sigma = self.scheduler.schedule[t]
|
||||
sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
|
||||
|
|
|
@ -44,6 +44,29 @@ from diffusers.testing_utils import slow, torch_device
|
|||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
def test_progress_bar(capsys):
    """Smoke-test the pipeline progress bar: it must print to stderr by
    default and stay silent once disabled via `set_progress_bar_config`."""
    # Tiny UNet + 10-step scheduler so the test runs quickly.
    unet = UNet2DModel(
        block_out_channels=(32, 64),
        layers_per_block=2,
        sample_size=32,
        in_channels=3,
        out_channels=3,
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    noise_scheduler = DDPMScheduler(num_train_timesteps=10)

    pipeline = DDPMPipeline(unet, noise_scheduler).to(torch_device)

    # tqdm writes the bar to stderr; after 10 steps it reports "10/10".
    pipeline(output_type="numpy")["sample"]
    captured = capsys.readouterr()
    assert "10/10" in captured.err, "Progress bar has to be displayed"

    # Disabling through the config must suppress all stderr output.
    pipeline.set_progress_bar_config(disable=True)
    pipeline(output_type="numpy")["sample"]
    captured = capsys.readouterr()
    assert captured.err == "", "Progress bar should be disabled"
|
||||
|
||||
|
||||
class PipelineTesterMixin(unittest.TestCase):
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
# 1. Load models
|
||||
|
|
Loading…
Reference in New Issue