Fix scheduler inference steps error with power of 3 (#466)

* initial attempt at solving

* fix pndm power of 3 inference_step

* add power of 3 test

* fix index in pndm test, remove ddim test

* add comments, change to round()
This commit is contained in:
Nathan Lambert 2022-09-13 09:48:33 -06:00 committed by GitHub
parent da990633a9
commit b56f102765
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 7 deletions

View File

@ -145,9 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
self.num_inference_steps = num_inference_steps
self.timesteps = np.arange(
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
)[::-1].copy()
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
self.timesteps += offset
self.set_format(tensor_format=self.tensor_format)

View File

@ -143,9 +143,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
"""
self.num_inference_steps = num_inference_steps
self._timesteps = list(
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
)
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
self._offset = offset
self._timesteps = np.array([t + self._offset for t in self._timesteps])

View File

@ -379,7 +379,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
def test_inference_steps(self):
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
self.check_over_forward(num_inference_steps=num_inference_steps)
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_eta(self):
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
@ -622,6 +622,23 @@ class PNDMSchedulerTest(SchedulerCommonTest):
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
def test_pow_of_3_inference_steps(self):
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
num_inference_steps = 27
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(num_inference_steps)
# before power of 3 fix, would error on first step, so we only need to do two
for i, t in enumerate(scheduler.prk_timesteps[:2]):
sample = scheduler.step_prk(residual, t, sample).prev_sample
def test_inference_plms_no_past_residuals(self):
with self.assertRaises(ValueError):
scheduler_class = self.scheduler_classes[0]