Add back-compatibility to LMS timesteps (#750)
* Add back-compatibility to LMS timesteps * style
This commit is contained in:
parent
c119dc4c04
commit
df9c070174
|
@ -202,11 +202,6 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
|
||||
warnings.warn(
|
||||
f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
|
||||
"Make sure to pass one of the `scheduler.timesteps`"
|
||||
)
|
||||
if not self.is_scale_input_called:
|
||||
warnings.warn(
|
||||
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
|
||||
|
@ -215,7 +210,18 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
if (
|
||||
isinstance(timestep, int)
|
||||
or isinstance(timestep, torch.IntTensor)
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
warnings.warn(
|
||||
"Integer timesteps in `LMSDiscreteScheduler.step()` are deprecated and will be removed in version"
|
||||
" 0.5.0. Make sure to pass one of the `scheduler.timesteps`."
|
||||
)
|
||||
step_index = timestep
|
||||
else:
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
|
||||
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
|
||||
|
@ -250,7 +256,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
|||
sigmas = self.sigmas.to(original_samples.device)
|
||||
schedule_timesteps = self.timesteps.to(original_samples.device)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor):
|
||||
warnings.warn(
|
||||
"Integer timesteps in `LMSDiscreteScheduler.add_noise()` are deprecated and will be removed in"
|
||||
" version 0.5.0. Make sure to pass values from `scheduler.timesteps`."
|
||||
)
|
||||
step_indices = timesteps
|
||||
else:
|
||||
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
|
||||
|
||||
sigma = sigmas[step_indices].flatten()
|
||||
while len(sigma.shape) < len(original_samples.shape):
|
||||
|
|
Loading…
Reference in New Issue