[Type hint] scheduling lms discrete (#360)
* [Type hint] scheduling karras ve
* [Type hint] scheduling lms discrete
parent 3c1cdd3359
commit be52be7215
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple, Union
+from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
-        num_train_timesteps=1000,
-        beta_start=0.0001,
-        beta_end=0.02,
-        beta_schedule="linear",
-        trained_betas=None,
-        timestep_values=None,
-        tensor_format="pt",
+        num_train_timesteps: int = 1000,
+        beta_start: float = 0.0001,
+        beta_end: float = 0.02,
+        beta_schedule: str = "linear",
+        trained_betas: Optional[np.ndarray] = None,
+        timestep_values: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
     ):
         """
         Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
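For illustration only (not part of the diff), a minimal sketch of constructing the scheduler with the keyword arguments whose annotations this hunk adds. It assumes the class is importable from the top-level diffusers package; the values simply mirror the defaults in the signature above.

    import numpy as np
    from diffusers import LMSDiscreteScheduler

    # Each keyword matches a newly annotated parameter.
    scheduler = LMSDiscreteScheduler(
        num_train_timesteps=1000,   # int
        beta_start=0.0001,          # float
        beta_end=0.02,              # float
        beta_schedule="linear",     # str
        trained_betas=None,         # Optional[np.ndarray]
        timestep_values=None,       # Optional[np.ndarray]
        tensor_format="pt",         # str
    )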
@@ -79,7 +79,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         return integrated_coeff
 
-    def set_timesteps(self, num_inference_steps):
+    def set_timesteps(self, num_inference_steps: int):
         self.num_inference_steps = num_inference_steps
         self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
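A short usage sketch for the newly annotated set_timesteps, continuing the scheduler instance from the previous sketch; the shape and endpoint follow directly from the np.linspace call shown above.

    # num_inference_steps is a plain int; timesteps becomes a descending float array.
    scheduler.set_timesteps(50)
    assert scheduler.timesteps.shape == (50,)
    assert scheduler.timesteps[0] == scheduler.num_train_timesteps - 1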
@@ -127,7 +127,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
 
         return SchedulerOutput(prev_sample=prev_sample)
 
-    def add_noise(self, original_samples, noise, timesteps):
+    def add_noise(
+        self,
+        original_samples: Union[torch.FloatTensor, np.ndarray],
+        noise: Union[torch.FloatTensor, np.ndarray],
+        timesteps: Union[torch.IntTensor, np.ndarray],
+    ) -> Union[torch.FloatTensor, np.ndarray]:
         sigmas = self.match_shape(self.sigmas[timesteps], noise)
         noisy_samples = original_samples + noise * sigmas
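And a hypothetical call to the retyped add_noise, shown with the torch branch of the Union types. It assumes set_timesteps has already been called so the self.sigmas[timesteps] lookup in the body above is valid; tensor shapes are illustrative only.

    import torch

    original_samples = torch.randn(1, 3, 64, 64)    # torch.FloatTensor branch of the Union
    noise = torch.randn_like(original_samples)
    timesteps = torch.tensor([10])                   # integer timesteps index into self.sigmas

    # noisy = original_samples + noise * sigma, with sigma broadcast to the sample shape
    noisy = scheduler.add_noise(original_samples, noise, timesteps)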