repair DDIM/PLMS/UniPC batches
This commit is contained in:
parent
007ecfbb29
commit
aeb76ef174
|
@ -51,10 +51,9 @@ class CFGDenoiserTimesteps(CFGDenoiser):
|
||||||
self.alphas = shared.sd_model.alphas_cumprod
|
self.alphas = shared.sd_model.alphas_cumprod
|
||||||
|
|
||||||
def get_pred_x0(self, x_in, x_out, sigma):
|
def get_pred_x0(self, x_in, x_out, sigma):
|
||||||
ts = int(sigma.item())
|
ts = sigma.to(dtype=int)
|
||||||
|
|
||||||
s_in = x_in.new_ones([x_in.shape[0]])
|
a_t = self.alphas[ts][:, None, None, None]
|
||||||
a_t = self.alphas[ts].item() * s_in
|
|
||||||
sqrt_one_minus_at = (1 - a_t).sqrt()
|
sqrt_one_minus_at = (1 - a_t).sqrt()
|
||||||
|
|
||||||
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()
|
||||||
|
|
|
@ -16,16 +16,17 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
|
||||||
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones((x.shape[0]))
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
for i in tqdm.trange(len(timesteps) - 1, disable=disable):
|
||||||
index = len(timesteps) - 1 - i
|
index = len(timesteps) - 1 - i
|
||||||
|
|
||||||
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
e_t = model(x, timesteps[index].item() * s_in, **extra_args)
|
||||||
|
|
||||||
a_t = alphas[index].item() * s_in
|
a_t = alphas[index].item() * s_x
|
||||||
a_prev = alphas_prev[index].item() * s_in
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
sigma_t = sigmas[index].item() * s_in
|
sigma_t = sigmas[index].item() * s_x
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
|
||||||
|
@ -47,13 +48,14 @@ def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
|
||||||
|
|
||||||
extra_args = {} if extra_args is None else extra_args
|
extra_args = {} if extra_args is None else extra_args
|
||||||
s_in = x.new_ones([x.shape[0]])
|
s_in = x.new_ones([x.shape[0]])
|
||||||
|
s_x = x.new_ones((x.shape[0], 1, 1, 1))
|
||||||
old_eps = []
|
old_eps = []
|
||||||
|
|
||||||
def get_x_prev_and_pred_x0(e_t, index):
|
def get_x_prev_and_pred_x0(e_t, index):
|
||||||
# select parameters corresponding to the currently considered timestep
|
# select parameters corresponding to the currently considered timestep
|
||||||
a_t = alphas[index].item() * s_in
|
a_t = alphas[index].item() * s_x
|
||||||
a_prev = alphas_prev[index].item() * s_in
|
a_prev = alphas_prev[index].item() * s_x
|
||||||
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in
|
sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
|
||||||
|
|
||||||
# current prediction for x_0
|
# current prediction for x_0
|
||||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||||
|
|
Loading…
Reference in New Issue