Fix DDIM on Windows not using int64 for timesteps (#819)
This commit is contained in:
parent
728a3f3ec1
commit
a3efa433ea
|
@ -149,7 +149,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
# setable values
|
# setable values
|
||||||
self.num_inference_steps = None
|
self.num_inference_steps = None
|
||||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||||
|
|
||||||
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
|
||||||
"""
|
"""
|
||||||
|
@ -192,7 +192,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
|
||||||
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
|
||||||
# creates integer timesteps by multiplying by ratio
|
# creates integer timesteps by multiplying by ratio
|
||||||
# casting to int to avoid issues when num_inference_step is power of 3
|
# casting to int to avoid issues when num_inference_step is power of 3
|
||||||
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy()
|
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
|
||||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||||
self.timesteps += offset
|
self.timesteps += offset
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue