From fc67917a181a4cbd539c794470948ffeb89e5b1d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 24 Jun 2022 17:35:19 +0000 Subject: [PATCH] up --- run.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/run.py b/run.py index 61e29603..b2ec6eea 100755 --- a/run.py +++ b/run.py @@ -269,20 +269,21 @@ with torch.no_grad(): for i in range(sde.N): t = timesteps[i] vec_t = torch.ones(shape[0], device=t.device) * t -# x, x_mean = corrector_update_fn(x, vec_t, model=model) -# x, x_mean = predictor_update_fn(x, vec_t, model=model) - x, x_mean = new_corrector.update_fn(x, vec_t) - x, x_mean = new_predictor.update_fn(x, vec_t) + x, x_mean = corrector_update_fn(x, vec_t, model=model) + x, x_mean = predictor_update_fn(x, vec_t, model=model) +# x, x_mean = new_corrector.update_fn(x, vec_t) +# x, x_mean = new_predictor.update_fn(x, vec_t) x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1) +save_image(x) + # for 5 -#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong" -#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong" +#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" # for 1000 -assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong" -assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong" +assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}" +assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}" -save_image(x)