This commit is contained in:
Patrick von Platen 2022-06-24 17:35:19 +00:00
parent 7ca832cac9
commit fc67917a18
1 changed files with 10 additions and 9 deletions

19
run.py
View File

@ -269,20 +269,21 @@ with torch.no_grad():
for i in range(sde.N):
t = timesteps[i]
vec_t = torch.ones(shape[0], device=t.device) * t
# x, x_mean = corrector_update_fn(x, vec_t, model=model)
# x, x_mean = predictor_update_fn(x, vec_t, model=model)
x, x_mean = new_corrector.update_fn(x, vec_t)
x, x_mean = new_predictor.update_fn(x, vec_t)
x, x_mean = corrector_update_fn(x, vec_t, model=model)
x, x_mean = predictor_update_fn(x, vec_t, model=model)
# x, x_mean = new_corrector.update_fn(x, vec_t)
# x, x_mean = new_predictor.update_fn(x, vec_t)
x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
save_image(x)
# for 5
#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong"
#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong"
#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
# for 1000
assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong"
assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong"
assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
save_image(x)