Fix logspam and live previews
parent 1253199889
commit 21880eb9e5
@@ -19,9 +19,10 @@ class UniPCSampler(object):
                 attr = attr.to(torch.device("cuda"))
         setattr(self, name, attr)
 
-    def set_hooks(self, before, after):
-        self.before_sample = before
-        self.after_sample = after
+    def set_hooks(self, before_sample, after_sample, after_update):
+        self.before_sample = before_sample
+        self.after_sample = after_sample
+        self.after_update = after_update
 
     @torch.no_grad()
     def sample(self,
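The hook setter grows a third callback, after_update, alongside the renamed before_sample/after_sample parameters. As a point of reference, a minimal sketch of registering all three hooks on an already-constructed sampler (the lambda bodies are placeholders, not part of the commit):

    # sampler is assumed to be a constructed UniPCSampler
    sampler.set_hooks(
        lambda x, t, c, u: (x, t, c, u),        # before_sample: may rewrite the model inputs
        lambda x, t, c, u, r: (x, t, c, u, r),  # after_sample: may rewrite the result tuple
        lambda x, mx: None,                     # after_update: fired once per solver step
    )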
@@ -50,9 +51,17 @@ class UniPCSampler(object):
                ):
         if conditioning is not None:
             if isinstance(conditioning, dict):
-                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                ctmp = conditioning[list(conditioning.keys())[0]]
+                while isinstance(ctmp, list): ctmp = ctmp[0]
+                cbs = ctmp.shape[0]
                 if cbs != batch_size:
                     print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+
+            elif isinstance(conditioning, list):
+                for ctmp in conditioning:
+                    if ctmp.shape[0] != batch_size:
+                        print(f"Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}")
+
             else:
                 if conditioning.shape[0] != batch_size:
                     print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
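The dict branch now unwraps nested lists before reading .shape, and a new elif covers plain list conditioning. A small self-contained illustration of the unwrapping loop, with an invented shape:

    import torch

    # A conditioning dict value may nest tensors inside lists; unwrap to the first tensor.
    ctmp = [[torch.zeros(2, 77, 768)]]
    while isinstance(ctmp, list):
        ctmp = ctmp[0]
    print(ctmp.shape[0])  # -> 2, the effective conditioning batch size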
@@ -60,6 +69,7 @@ class UniPCSampler(object):
         # sampling
         C, H, W = shape
         size = (batch_size, C, H, W)
+        print(f'Data shape for UniPC sampling is {size}, eta {eta}')
 
         device = self.model.betas.device
         if x_T is None:
@@ -79,7 +89,7 @@ class UniPCSampler(object):
             guidance_scale=unconditional_guidance_scale,
         )
 
-        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
+        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
         x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
 
         return x.to(device), None
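The after_update keyword threaded into the UniPC constructor above is what drives live previews: the solver invokes it with the running latent after every update. A hedged sketch of a minimal consumer, reusing the names from the hunk (the hook body is illustrative):

    # Illustrative preview hook: called once per solver step with the current
    # latent x and the latest model output model_x.
    def preview_hook(x, model_x):
        print(f"step latent: mean={x.mean().item():.4f} std={x.std().item():.4f}")

    uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False,
                   condition=conditioning, unconditional_condition=unconditional_conditioning,
                   before_sample=None, after_sample=None, after_update=preview_hook)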
@@ -378,7 +378,8 @@ class UniPC:
         condition=None,
         unconditional_condition=None,
         before_sample=None,
-        after_sample=None
+        after_sample=None,
+        after_update=None
     ):
         """Construct a UniPC.
 
@@ -394,6 +395,7 @@ class UniPC:
         self.unconditional_condition = unconditional_condition
         self.before_sample = before_sample
         self.after_sample = after_sample
+        self.after_update = after_update
 
     def dynamic_thresholding_fn(self, x0, t=None):
         """
@@ -434,15 +436,6 @@ class UniPC:
         noise = self.noise_prediction_fn(x, t)
         dims = x.dim()
         alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
-        from pprint import pp
-        print("X:")
-        pp(x)
-        print("sigma_t:")
-        pp(sigma_t)
-        print("noise:")
-        pp(noise)
-        print("alpha_t:")
-        pp(alpha_t)
         x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
         if self.thresholding:
             p = 0.995 # A hyperparameter in the paper of "Imagen" [1].
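Deleting the pprint calls is the core of the logspam fix; they fired on every model evaluation. If that kind of tensor tracing is needed again, one option is to gate a compact summary behind a module-level flag instead of leaving bare prints (a suggestion, not part of the commit):

    DEBUG_TENSORS = False  # flip on locally when diagnosing the solver

    def debug_tensor(name, t):
        # Print a one-line summary instead of pretty-printing whole tensors.
        if DEBUG_TENSORS:
            print(f"{name}: shape={tuple(t.shape)} mean={t.mean().item():.4f}")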
@@ -524,7 +517,7 @@ class UniPC:
             return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
 
     def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
-        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
+        #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
         ns = self.noise_schedule
         assert order <= len(model_prev_list)
 
@@ -568,7 +561,7 @@ class UniPC:
             A_p = C_inv_p
 
         if use_corrector:
-            print('using corrector')
+            #print('using corrector')
             C_inv = torch.linalg.inv(C)
             A_c = C_inv
 
@@ -627,7 +620,7 @@ class UniPC:
         return x_t, model_t
 
     def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
-        print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
+        #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
         ns = self.noise_schedule
         assert order <= len(model_prev_list)
         dims = x.dim()
@@ -695,7 +688,7 @@ class UniPC:
             D1s = None
 
         if use_corrector:
-            print('using corrector')
+            #print('using corrector')
             # for order 1, we use a simplified version
             if order == 1:
                 rhos_c = torch.tensor([0.5], device=b.device)
@@ -755,8 +748,9 @@ class UniPC:
         t_T = self.noise_schedule.T if t_start is None else t_start
         device = x.device
         if method == 'multistep':
-            assert steps >= order
+            assert steps >= order, "UniPC order must be <= sampling steps"
             timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
+            print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
             assert timesteps.shape[0] - 1 == steps
             with torch.no_grad():
                 vec_t = timesteps[0].expand((x.shape[0]))
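The new assert message records the existing constraint: an order-k multistep solver needs at least k steps to build its history of previous model outputs, and the timestep grid carries steps + 1 boundary points, hence the timesteps.shape[0] - 1 == steps check. A toy illustration (this linspace merely stands in for the 'time_uniform' skip type; get_time_steps itself is not shown in this diff):

    import torch

    steps, t_T, t_0 = 10, 1.0, 0.001
    timesteps = torch.linspace(t_T, t_0, steps + 1)  # steps + 1 boundary points
    assert timesteps.shape[0] - 1 == steps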
@@ -768,6 +762,8 @@ class UniPC:
                 x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                 if model_x is None:
                     model_x = self.model_fn(x, vec_t)
+                if self.after_update is not None:
+                    self.after_update(x, model_x)
                 model_prev_list.append(model_x)
                 t_prev_list.append(vec_t)
             for step in range(order, steps + 1):
@@ -776,13 +772,15 @@ class UniPC:
                     step_order = min(order, steps + 1 - step)
                 else:
                     step_order = order
-                print('this step order:', step_order)
+                #print('this step order:', step_order)
                 if step == steps:
-                    print('do not run corrector at the last step')
+                    #print('do not run corrector at the last step')
                     use_corrector = False
                 else:
                     use_corrector = True
                 x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
+                if self.after_update is not None:
+                    self.after_update(x, model_x)
                 for i in range(order - 1):
                     t_prev_list[i] = t_prev_list[i + 1]
                     model_prev_list[i] = model_prev_list[i + 1]
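For context on the loop these hooks sit in: after each accepted step the solver shifts its order-deep history left by one slot (the newest output then fills the last slot in code just beyond this hunk), so after_update always sees the freshest latent before the shift. A standalone sketch of that rotation, with strings standing in for tensors:

    order = 3
    model_prev_list = ["m0", "m1", "m2"]  # oldest first
    new_model = "m3"
    for i in range(order - 1):
        model_prev_list[i] = model_prev_list[i + 1]
    model_prev_list[-1] = new_model
    print(model_prev_list)  # ['m1', 'm2', 'm3']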
@@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:
 
         return x, ts, cond, unconditional_conditioning
 
-    def after_sample(self, x, ts, cond, uncond, res):
-        if self.is_unipc:
-            # unipc model_fn returns (pred_x0)
-            # p_sample_ddim returns (x_prev, pred_x0)
-            res = (None, res[0])
-
+    def update_step(self, last_latent):
         if self.mask is not None:
-            self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+            self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
         else:
-            self.last_latent = res[1]
+            self.last_latent = last_latent
 
         sd_samplers_common.store_latent(self.last_latent)
 
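update_step centralizes the inpainting composite: where mask is 1 the original init_latent is kept, and nmask (in practice the complement of mask) lets the freshly denoised latent through. A toy numeric check of the blend, with invented shapes and values:

    import torch

    init_latent = torch.full((1, 4, 2, 2), 5.0)
    last_latent = torch.zeros(1, 4, 2, 2)
    mask = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])  # 1 = keep the original latent
    nmask = 1 - mask
    out = init_latent * mask + nmask * last_latent
    print(out[0, 0])  # tensor([[5., 0.], [0., 5.]])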
@@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
         state.sampling_step = self.step
         shared.total_tqdm.update()
 
+    def after_sample(self, x, ts, cond, uncond, res):
+        if not self.is_unipc:
+            self.update_step(res[1])
+
+        return x, ts, cond, uncond, res
+
+    def unipc_after_update(self, x, model_x):
+        self.update_step(x)
+
     def initialize(self, p):
         self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
         if self.eta != 0.0:
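The split sends both sampler families through one preview path: DDIM-style samplers still update from res[1] (pred_x0) inside after_sample, while UniPC defers to the dedicated after_update callback. The two new methods restated with comments for clarity:

    def after_sample(self, x, ts, cond, uncond, res):
        if not self.is_unipc:            # DDIM-style: res = (x_prev, pred_x0)
            self.update_step(res[1])
        return x, ts, cond, uncond, res  # the hook must pass its arguments through

    def unipc_after_update(self, x, model_x):
        self.update_step(x)              # UniPC: preview the running latent itself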
@@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
             if hasattr(self.sampler, fieldname):
                 setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
         if self.is_unipc:
-            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
+            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
 
         self.mask = p.mask if hasattr(p, 'mask') else None
         self.nmask = p.nmask if hasattr(p, 'nmask') else None
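The lambda wrappers forward every argument unchanged, so the registration is functionally equivalent to passing the bound methods directly; this shorter form is an observation about the call site, not a change the commit makes:

    self.sampler.set_hooks(self.before_sample, self.after_sample, self.unipc_after_update)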