Refactor progress bar (#242)
* Refactor progress bar of pipeline __call__ * Make any tqdm configs available * remove init * add some tests * remove file * finish * make style * improve progress bar test Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent
efa773afd2
commit
5e84353eba
|
@ -23,6 +23,7 @@ import torch
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
from .configuration_utils import ConfigMixin
|
from .configuration_utils import ConfigMixin
|
||||||
from .utils import DIFFUSERS_CACHE, logging
|
from .utils import DIFFUSERS_CACHE, logging
|
||||||
|
@ -266,3 +267,16 @@ class DiffusionPipeline(ConfigMixin):
|
||||||
pil_images = [Image.fromarray(image) for image in images]
|
pil_images = [Image.fromarray(image) for image in images]
|
||||||
|
|
||||||
return pil_images
|
return pil_images
|
||||||
|
|
||||||
|
def progress_bar(self, iterable):
|
||||||
|
if not hasattr(self, "_progress_bar_config"):
|
||||||
|
self._progress_bar_config = {}
|
||||||
|
elif not isinstance(self._progress_bar_config, dict):
|
||||||
|
raise ValueError(
|
||||||
|
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
return tqdm(iterable, **self._progress_bar_config)
|
||||||
|
|
||||||
|
def set_progress_bar_config(self, **kwargs):
|
||||||
|
self._progress_bar_config = kwargs
|
||||||
|
|
|
@ -18,8 +18,6 @@ import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from ...pipeline_utils import DiffusionPipeline
|
from ...pipeline_utils import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,7 +54,7 @@ class DDIMPipeline(DiffusionPipeline):
|
||||||
# set step values
|
# set step values
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
for t in tqdm(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
# 1. predict noise model_output
|
# 1. predict noise model_output
|
||||||
model_output = self.unet(image, t)["sample"]
|
model_output = self.unet(image, t)["sample"]
|
||||||
|
|
||||||
|
|
|
@ -18,8 +18,6 @@ import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from ...pipeline_utils import DiffusionPipeline
|
from ...pipeline_utils import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,7 +51,7 @@ class DDPMPipeline(DiffusionPipeline):
|
||||||
# set step values
|
# set step values
|
||||||
self.scheduler.set_timesteps(1000)
|
self.scheduler.set_timesteps(1000)
|
||||||
|
|
||||||
for t in tqdm(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
# 1. predict noise model_output
|
# 1. predict noise model_output
|
||||||
model_output = self.unet(image, t)["sample"]
|
model_output = self.unet(image, t)["sample"]
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,6 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from transformers.modeling_outputs import BaseModelOutput
|
from transformers.modeling_outputs import BaseModelOutput
|
||||||
|
@ -83,7 +82,7 @@ class LDMTextToImagePipeline(DiffusionPipeline):
|
||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_kwargs["eta"] = eta
|
extra_kwargs["eta"] = eta
|
||||||
|
|
||||||
for t in tqdm(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
if guidance_scale == 1.0:
|
if guidance_scale == 1.0:
|
||||||
# guidance_scale of 1 means no guidance
|
# guidance_scale of 1 means no guidance
|
||||||
latents_input = latents
|
latents_input = latents
|
||||||
|
|
|
@ -3,8 +3,6 @@ import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from ...pipeline_utils import DiffusionPipeline
|
from ...pipeline_utils import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,7 +43,7 @@ class LDMPipeline(DiffusionPipeline):
|
||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_kwargs["eta"] = eta
|
extra_kwargs["eta"] = eta
|
||||||
|
|
||||||
for t in tqdm(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
# predict the noise residual
|
# predict the noise residual
|
||||||
noise_prediction = self.unet(latents, t)["sample"]
|
noise_prediction = self.unet(latents, t)["sample"]
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
|
|
|
@ -18,8 +18,6 @@ import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from ...pipeline_utils import DiffusionPipeline
|
from ...pipeline_utils import DiffusionPipeline
|
||||||
|
|
||||||
|
|
||||||
|
@ -54,7 +52,7 @@ class PNDMPipeline(DiffusionPipeline):
|
||||||
image = image.to(self.device)
|
image = image.to(self.device)
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
for t in tqdm(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
model_output = self.unet(image, t)["sample"]
|
model_output = self.unet(image, t)["sample"]
|
||||||
|
|
||||||
image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
image = self.scheduler.step(model_output, t, image)["prev_sample"]
|
||||||
|
|
|
@ -4,7 +4,6 @@ import warnings
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
|
|
||||||
class ScoreSdeVePipeline(DiffusionPipeline):
|
class ScoreSdeVePipeline(DiffusionPipeline):
|
||||||
|
@ -37,7 +36,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
self.scheduler.set_sigmas(num_inference_steps)
|
self.scheduler.set_sigmas(num_inference_steps)
|
||||||
|
|
||||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||||
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
|
sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=self.device)
|
||||||
|
|
||||||
# correction step
|
# correction step
|
||||||
|
|
|
@ -4,7 +4,6 @@ from typing import List, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||||
|
@ -133,7 +132,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||||
if accepts_eta:
|
if accepts_eta:
|
||||||
extra_step_kwargs["eta"] = eta
|
extra_step_kwargs["eta"] = eta
|
||||||
|
|
||||||
for i, t in tqdm(enumerate(self.scheduler.timesteps)):
|
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
|
||||||
# expand the latents if we are doing classifier free guidance
|
# expand the latents if we are doing classifier free guidance
|
||||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||||
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
if isinstance(self.scheduler, LMSDiscreteScheduler):
|
||||||
|
|
|
@ -3,8 +3,6 @@ import warnings
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tqdm.auto import tqdm
|
|
||||||
|
|
||||||
from ...models import UNet2DModel
|
from ...models import UNet2DModel
|
||||||
from ...pipeline_utils import DiffusionPipeline
|
from ...pipeline_utils import DiffusionPipeline
|
||||||
from ...schedulers import KarrasVeScheduler
|
from ...schedulers import KarrasVeScheduler
|
||||||
|
@ -53,7 +51,7 @@ class KarrasVePipeline(DiffusionPipeline):
|
||||||
|
|
||||||
self.scheduler.set_timesteps(num_inference_steps)
|
self.scheduler.set_timesteps(num_inference_steps)
|
||||||
|
|
||||||
for t in tqdm(self.scheduler.timesteps):
|
for t in self.progress_bar(self.scheduler.timesteps):
|
||||||
# here sigma_t == t_i from the paper
|
# here sigma_t == t_i from the paper
|
||||||
sigma = self.scheduler.schedule[t]
|
sigma = self.scheduler.schedule[t]
|
||||||
sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
|
sigma_prev = self.scheduler.schedule[t - 1] if t > 0 else 0
|
||||||
|
|
|
@ -44,6 +44,29 @@ from diffusers.testing_utils import slow, torch_device
|
||||||
torch.backends.cuda.matmul.allow_tf32 = False
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
|
||||||
|
|
||||||
|
def test_progress_bar(capsys):
|
||||||
|
model = UNet2DModel(
|
||||||
|
block_out_channels=(32, 64),
|
||||||
|
layers_per_block=2,
|
||||||
|
sample_size=32,
|
||||||
|
in_channels=3,
|
||||||
|
out_channels=3,
|
||||||
|
down_block_types=("DownBlock2D", "AttnDownBlock2D"),
|
||||||
|
up_block_types=("AttnUpBlock2D", "UpBlock2D"),
|
||||||
|
)
|
||||||
|
scheduler = DDPMScheduler(num_train_timesteps=10)
|
||||||
|
|
||||||
|
ddpm = DDPMPipeline(model, scheduler).to(torch_device)
|
||||||
|
ddpm(output_type="numpy")["sample"]
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert "10/10" in captured.err, "Progress bar has to be displayed"
|
||||||
|
|
||||||
|
ddpm.set_progress_bar_config(disable=True)
|
||||||
|
ddpm(output_type="numpy")["sample"]
|
||||||
|
captured = capsys.readouterr()
|
||||||
|
assert captured.err == "", "Progress bar should be disabled"
|
||||||
|
|
||||||
|
|
||||||
class PipelineTesterMixin(unittest.TestCase):
|
class PipelineTesterMixin(unittest.TestCase):
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
# 1. Load models
|
# 1. Load models
|
||||||
|
|
Loading…
Reference in New Issue