vectorize kl-optimal sigma calculation
Co-authored-by: mamei16 <marcel.1710@live.de>
This commit is contained in:
parent
83266205d0
commit
3a215deff2
|
@ -34,9 +34,8 @@ def sgm_uniform(n, sigma_min, sigma_max, inner_model, device):
|
|||
def kl_optimal(n, sigma_min, sigma_max, device):
|
||||
alpha_min = torch.arctan(torch.tensor(sigma_min, device=device))
|
||||
alpha_max = torch.arctan(torch.tensor(sigma_max, device=device))
|
||||
sigmas = torch.empty((n+1,), device=device)
|
||||
for i in range(n+1):
|
||||
sigmas[i] = torch.tan((i/n) * alpha_min + (1.0-i/n) * alpha_max)
|
||||
step_indices = torch.arange(n + 1, device=device)
|
||||
sigmas = torch.tan(step_indices / n * alpha_min + (1.0 - step_indices / n) * alpha_max)
|
||||
return sigmas
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue