Fix the LMS pytorch regression (#664)

* Fix the LMS pytorch regression

* Copy over the changes from #637

* Copy over the changes from #637

* Fix betas test
This commit is contained in:
Anton Lozhkov 2022-09-28 14:07:26 +02:00 committed by GitHub
parent 235770dd84
commit 765506ce28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 10 additions and 11 deletions

View File

@ -17,7 +17,6 @@ deps = {
"jaxlib": "jaxlib>=0.1.65,<=0.3.6",
"modelcards": "modelcards>=0.1.4",
"numpy": "numpy",
"onnxruntime": "onnxruntime",
"onnxruntime-gpu": "onnxruntime-gpu",
"pytest": "pytest",
"pytest-timeout": "pytest-timeout",

View File

@ -99,11 +99,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.sigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
# setable values
self.num_inference_steps = None
self.timesteps = np.arange(0, num_train_timesteps)[::-1] # to be consistent has to be smaller than sigmas by 1
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps)
self.derivatives = []
def get_lms_coefficient(self, order, t, current_order):
@ -137,17 +140,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
low_idx = np.floor(timesteps).astype(int)
high_idx = np.ceil(timesteps).astype(int)
frac = np.mod(timesteps, 1.0)
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
sigmas = (1 - frac) * sigmas[low_idx] + frac * sigmas[high_idx]
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
self.sigmas = torch.from_numpy(sigmas)
self.timesteps = torch.from_numpy(timesteps)
self.timesteps = timesteps.astype(int)
self.derivatives = []
def step(

View File

@ -844,7 +844,7 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
self.check_over_configs(num_train_timesteps=timesteps)
def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
for beta_start, beta_end in zip([0.00001, 0.0001, 0.001], [0.0002, 0.002, 0.02]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)
def test_schedules(self):
@ -876,5 +876,5 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 1006.370) < 1e-2
assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3