Make tqdm calls notebook-compatible
This commit is contained in:
parent
ffe7b93b60
commit
1820024005
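Context for the change below: `from tqdm.auto import tqdm` selects a frontend-appropriate progress bar at import time — the ipywidgets/HTML bar when running under Jupyter, the plain-text console bar otherwise. A rough sketch of the effect (the real tqdm.auto inspects the active IPython shell rather than relying on import success, so this try/except is only an approximation):

# Approximation of what `from tqdm.auto import tqdm` resolves to.
try:
    from tqdm.notebook import tqdm  # HTML/ipywidgets bar for Jupyter notebooks
except ImportError:
    from tqdm import tqdm           # carriage-return text bar for terminals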
@@ -16,7 +16,7 @@
 
 import torch
 
-import tqdm
+from tqdm.auto import tqdm
 
 from ...pipeline_utils import DiffusionPipeline
 
@@ -44,7 +44,7 @@ class DDIMPipeline(DiffusionPipeline):
         # set step values
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm.tqdm(self.scheduler.timesteps):
+        for t in tqdm(self.scheduler.timesteps):
             # 1. predict noise model_output
             with torch.no_grad():
                 model_output = self.unet(image, t)
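Note that the import swap forces the call-site change above: `tqdm` is now bound to the progress-bar class rather than the module, so the old `tqdm.tqdm(...)` spelling would raise an AttributeError. A small illustration:

from tqdm.auto import tqdm  # binds the class, not the `tqdm` module

# tqdm.tqdm(range(3))       # AttributeError: the class has no `tqdm` attribute
for _ in tqdm(range(3)):    # correct: call the class directly
    pass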
@@ -16,7 +16,7 @@
 
 import torch
 
-import tqdm
+from tqdm.auto import tqdm
 
 from ...pipeline_utils import DiffusionPipeline
 
@@ -41,7 +41,7 @@ class DDPMPipeline(DiffusionPipeline):
         image = image.to(torch_device)
 
         num_prediction_steps = len(self.scheduler)
-        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+        for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             # 1. predict noise model_output
             with torch.no_grad():
                 model_output = self.unet(image, t)
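The reverse-timestep loops here and below pass `total=` explicitly because `reversed(range(n))` returns an iterator with no `__len__`, so tqdm cannot size the bar from the iterable and would fall back to a bare counter. A minimal illustration:

from tqdm.auto import tqdm

num_prediction_steps = 1000
steps = reversed(range(num_prediction_steps))      # len(steps) raises TypeError
for t in tqdm(steps, total=num_prediction_steps):  # explicit total keeps the percentage display
    pass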
@@ -22,7 +22,7 @@ import torch
 import torch.utils.checkpoint
 from torch import nn
 
-import tqdm
+from tqdm.auto import tqdm
 from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPVisionConfig, GPT2Tokenizer
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
@@ -778,7 +778,7 @@ class GlidePipeline(DiffusionPipeline):
 
         # 3. Run the text2image generation step
         num_prediction_steps = len(self.text_scheduler)
-        for t in tqdm.tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
+        for t in tqdm(reversed(range(num_prediction_steps)), total=num_prediction_steps):
             with torch.no_grad():
                 time_input = torch.tensor([t] * image.shape[0], device=torch_device)
                 model_output = text_model_fn(image, time_input, transformer_out)
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint
 
-import tqdm
+from tqdm.auto import tqdm
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_outputs import BaseModelOutput
@@ -599,7 +599,7 @@ class LatentDiffusionPipeline(DiffusionPipeline):
         # - eta -> η
         # - pred_image_direction -> "direction pointing to x_t"
         # - pred_prev_image -> "x_t-1"
-        for t in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
+        for t in tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
             # guidance_scale of 1 means no guidance
             if guidance_scale == 1.0:
                 image_in = image
@@ -1,6 +1,6 @@
 import torch
 
-import tqdm
+from tqdm.auto import tqdm
 
 from ...pipeline_utils import DiffusionPipeline
 
@@ -35,7 +35,7 @@ class LatentDiffusionUncondPipeline(DiffusionPipeline):
 
         self.scheduler.set_timesteps(num_inference_steps)
 
-        for t in tqdm.tqdm(self.scheduler.timesteps):
+        for t in tqdm(self.scheduler.timesteps):
             with torch.no_grad():
                 model_output = self.unet(image, t)
 
@@ -16,7 +16,7 @@
 
 import torch
 
-import tqdm
+from tqdm.auto import tqdm
 
 from ...pipeline_utils import DiffusionPipeline
 
@@ -43,7 +43,7 @@ class PNDMPipeline(DiffusionPipeline):
         image = image.to(torch_device)
 
         prk_time_steps = self.scheduler.get_prk_time_steps(num_inference_steps)
-        for t in tqdm.tqdm(range(len(prk_time_steps))):
+        for t in tqdm(range(len(prk_time_steps))):
             t_orig = prk_time_steps[t]
             model_output = self.unet(image, t_orig)
 
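Unlike the reversed-range loops above, this loop iterates over range(len(prk_time_steps)), which does support len(), so tqdm infers the bar length and no explicit total= is needed. A minimal sketch (the tensor below is a hypothetical stand-in for the scheduler's actual PRK timesteps):

from tqdm.auto import tqdm
import torch

prk_time_steps = torch.arange(999, -1, -10)  # hypothetical stand-in values
for t in tqdm(range(len(prk_time_steps))):   # len(range(...)) works, total inferred
    t_orig = prk_time_steps[t]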