This commit is contained in:
parent
7ca832cac9
commit
fc67917a18
19
run.py
19
run.py
|
@ -269,20 +269,21 @@ with torch.no_grad():
|
|||
for i in range(sde.N):
|
||||
t = timesteps[i]
|
||||
vec_t = torch.ones(shape[0], device=t.device) * t
|
||||
# x, x_mean = corrector_update_fn(x, vec_t, model=model)
|
||||
# x, x_mean = predictor_update_fn(x, vec_t, model=model)
|
||||
x, x_mean = new_corrector.update_fn(x, vec_t)
|
||||
x, x_mean = new_predictor.update_fn(x, vec_t)
|
||||
x, x_mean = corrector_update_fn(x, vec_t, model=model)
|
||||
x, x_mean = predictor_update_fn(x, vec_t, model=model)
|
||||
# x, x_mean = new_corrector.update_fn(x, vec_t)
|
||||
# x, x_mean = new_predictor.update_fn(x, vec_t)
|
||||
|
||||
x, n = inverse_scaler(x_mean if denoise else x), sde.N * (n_steps + 1)
|
||||
|
||||
|
||||
save_image(x)
|
||||
|
||||
# for 5
|
||||
#assert x.abs().sum().cpu().item() - 106114.90625 < 1e-2, "sum wrong"
|
||||
#assert x.abs().mean().cpu().item() - 34.5426139831543 < 1e-4, "mean wrong"
|
||||
#assert (x.abs().sum() - 106114.90625).cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
|
||||
#assert (x.abs().mean() - 34.5426139831543).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
|
||||
|
||||
# for 1000
|
||||
assert x.abs().sum().cpu().item() - 436.5811 < 1e-2, "sum wrong"
|
||||
assert x.abs().mean().cpu().item() - 0.1421 < 1e-4, "mean wrong"
|
||||
assert (x.abs().sum() - 436.5811).abs().sum().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
|
||||
assert (x.abs().mean() - 0.1421).abs().mean().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
|
||||
|
||||
save_image(x)
|
||||
|
|
Loading…
Reference in New Issue