UniPC progress bar adjustment

This commit is contained in:
Sakura-Luna 2023-05-11 12:26:04 +08:00
parent 22bcc7be42
commit ae17e97898
1 changed files with 37 additions and 33 deletions

View File

@ -1,7 +1,7 @@
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange
import tqdm
class NoiseScheduleVP:
@ -757,6 +757,7 @@ class UniPC:
vec_t = timesteps[0].expand((x.shape[0]))
model_prev_list = [self.model_fn(x, vec_t)]
t_prev_list = [vec_t]
with tqdm.tqdm(total=steps) as pbar:
# Init the first `order` values by lower order multistep DPM-Solver.
for init_order in range(1, order):
vec_t = timesteps[init_order].expand(x.shape[0])
@ -767,7 +768,9 @@ class UniPC:
self.after_update(x, model_x)
model_prev_list.append(model_x)
t_prev_list.append(vec_t)
for step in trange(order, steps + 1):
pbar.update()
for step in range(order, steps + 1):
vec_t = timesteps[step].expand(x.shape[0])
if lower_order_final:
step_order = min(order, steps + 1 - step)
@ -791,6 +794,7 @@ class UniPC:
if model_x is None:
model_x = self.model_fn(x, vec_t)
model_prev_list[-1] = model_x
pbar.update()
else:
raise NotImplementedError()
if denoise_to_zero: