[Type hint] scheduling lms discrete (#360)
* [Type hint] scheduling karras ve * [Type hint] scheduling lms discrete
This commit is contained in:
parent
3c1cdd3359
commit
be52be7215
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from typing import Tuple, Union
|
from typing import Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
@ -27,13 +27,13 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
@register_to_config
|
@register_to_config
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
num_train_timesteps=1000,
|
num_train_timesteps: int = 1000,
|
||||||
beta_start=0.0001,
|
beta_start: float = 0.0001,
|
||||||
beta_end=0.02,
|
beta_end: float = 0.02,
|
||||||
beta_schedule="linear",
|
beta_schedule: str = "linear",
|
||||||
trained_betas=None,
|
trained_betas: Optional[np.ndarray] = None,
|
||||||
timestep_values=None,
|
timestep_values: Optional[np.ndarray] = None,
|
||||||
tensor_format="pt",
|
tensor_format: str = "pt",
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
|
Linear Multistep Scheduler for discrete beta schedules. Based on the original k-diffusion implementation by
|
||||||
|
@ -79,7 +79,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return integrated_coeff
|
return integrated_coeff
|
||||||
|
|
||||||
def set_timesteps(self, num_inference_steps):
|
def set_timesteps(self, num_inference_steps: int):
|
||||||
self.num_inference_steps = num_inference_steps
|
self.num_inference_steps = num_inference_steps
|
||||||
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
|
self.timesteps = np.linspace(self.num_train_timesteps - 1, 0, num_inference_steps, dtype=float)
|
||||||
|
|
||||||
|
@ -127,7 +127,12 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||||
|
|
||||||
return SchedulerOutput(prev_sample=prev_sample)
|
return SchedulerOutput(prev_sample=prev_sample)
|
||||||
|
|
||||||
def add_noise(self, original_samples, noise, timesteps):
|
def add_noise(
|
||||||
|
self,
|
||||||
|
original_samples: Union[torch.FloatTensor, np.ndarray],
|
||||||
|
noise: Union[torch.FloatTensor, np.ndarray],
|
||||||
|
timesteps: Union[torch.IntTensor, np.ndarray],
|
||||||
|
) -> Union[torch.FloatTensor, np.ndarray]:
|
||||||
sigmas = self.match_shape(self.sigmas[timesteps], noise)
|
sigmas = self.match_shape(self.sigmas[timesteps], noise)
|
||||||
noisy_samples = original_samples + noise * sigmas
|
noisy_samples = original_samples + noise * sigmas
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue