Fix scheduler inference steps error with power of 3 (#466)
* initial attempt at solving * fix pndm power of 3 inference_step * add power of 3 test * fix index in pndm test, remove ddim test * add comments, change to round()
This commit is contained in:
parent
da990633a9
commit
b56f102765
|
@ -145,9 +145,10 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
|||
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = np.arange(
|
||||
0, self.config.num_train_timesteps, self.config.num_train_timesteps // self.num_inference_steps
|
||||
)[::-1].copy()
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
self.timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
|
||||
self.timesteps += offset
|
||||
self.set_format(tensor_format=self.tensor_format)
|
||||
|
||||
|
|
|
@ -143,9 +143,10 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
|
|||
optional value to shift timestep values up by. A value of 1 is used in stable diffusion for inference.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self._timesteps = list(
|
||||
range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
|
||||
)
|
||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||
# creates integer timesteps by multiplying by ratio
|
||||
# casting to int to avoid issues when num_inference_step is power of 3
|
||||
self._timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().tolist()
|
||||
self._offset = offset
|
||||
self._timesteps = np.array([t + self._offset for t in self._timesteps])
|
||||
|
||||
|
|
|
@ -379,7 +379,7 @@ class DDIMSchedulerTest(SchedulerCommonTest):
|
|||
|
||||
def test_inference_steps(self):
|
||||
for t, num_inference_steps in zip([1, 10, 50], [10, 50, 500]):
|
||||
self.check_over_forward(num_inference_steps=num_inference_steps)
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_eta(self):
|
||||
for t, eta in zip([1, 10, 49], [0.0, 0.5, 1.0]):
|
||||
|
@ -622,6 +622,23 @@ class PNDMSchedulerTest(SchedulerCommonTest):
|
|||
for t, num_inference_steps in zip([1, 5, 10], [10, 50, 100]):
|
||||
self.check_over_forward(time_step=t, num_inference_steps=num_inference_steps)
|
||||
|
||||
def test_pow_of_3_inference_steps(self):
|
||||
# earlier version of set_timesteps() caused an error indexing alpha's with inference steps as power of 3
|
||||
num_inference_steps = 27
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# before power of 3 fix, would error on first step, so we only need to do two
|
||||
for i, t in enumerate(scheduler.prk_timesteps[:2]):
|
||||
sample = scheduler.step_prk(residual, t, sample).prev_sample
|
||||
|
||||
def test_inference_plms_no_past_residuals(self):
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler_class = self.scheduler_classes[0]
|
||||
|
|
Loading…
Reference in New Issue