[Type hint] scheduling lms discrete (#360)

* [Type hint] scheduling karras ve

* [Type hint] scheduling lms discrete
This commit is contained in:
Santiago Víquez 2022-09-05 18:28:49 +02:00 committed by GitHub
parent 3c1cdd3359
commit be52be7215
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 10 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Union
from typing import Optional, Tuple, Union
import numpy as np
import torch
@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_train_timesteps=1000,
beta_start=0.0001,
beta_end=0.02,
beta_schedule="linear",
trained_betas=None,
timestep_values=None,
tensor_format="pt",
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[np.ndarray] = None,
timestep_values: Optional[np.ndarray] = None,
tensor_format: str = "pt",
):
"""
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
@ -79,7 +79,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return integrated_coeff
def set_timesteps(self, num_inference_steps):
def set_timesteps(self, num_inference_steps: int):
self.num_inference_steps = num_inference_steps
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
@ -127,7 +127,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
return SchedulerOutput(prev_sample=prev_sample)
def add_noise(self, original_samples, noise, timesteps):
def add_noise(
self,
original_samples: Union[torch.FloatTensor, np.ndarray],
noise: Union[torch.FloatTensor, np.ndarray],
timesteps: Union[torch.IntTensor, np.ndarray],
) -> Union[torch.FloatTensor, np.ndarray]:
sigmas = self.match_shape(self.sigmas[timesteps], noise)
noisy_samples = original_samples + noise * sigmas