[Scheduler design] The pragmatic approach (#719)

* init
* improve add_noise
* [debug start] run slow test
* [debug end]
* quick revert
* Add docstrings and warnings + API tests
* Make the warning less spammy
This commit is contained in:
parent 726aba089d
commit 6b09f370c4
@@ -57,7 +57,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.unet

-        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
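The one-line change above is the whole point of `init_noise_sigma`: the pipeline no longer needs to know whether its scheduler is variance-exploding (initial noise scaled by `sigma_max`) or variance-preserving (scaled by 1.0). A minimal sketch of what this unifies, using only classes touched in this commit (the shapes are arbitrary):

# --- sketch, not part of the diff ---
import torch
from diffusers import DDPMScheduler, LMSDiscreteScheduler

for scheduler in (DDPMScheduler(), LMSDiscreteScheduler()):
    # DDPM reports init_noise_sigma == 1.0; LMS reports its largest sigma.
    # The same line therefore draws correctly scaled noise for both families.
    sample = torch.randn(1, 3, 32, 32) * scheduler.init_noise_sigma
    print(type(scheduler).__name__, float(sample.std()))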
@@ -281,9 +281,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # It's more optimized to move all timesteps to correct device beforehand
         timesteps_tensor = self.scheduler.timesteps.to(self.device)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -297,10 +296,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -311,10 +307,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
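Taken together, the three hunks above collapse every LMS special case into one scheduler-agnostic loop. A condensed, runnable sketch of the resulting control flow; `fake_unet` is a hypothetical stand-in for the real UNet, and the latent shape is arbitrary:

# --- sketch, not part of the diff ---
import torch
from diffusers import LMSDiscreteScheduler

def fake_unet(x, t):
    # hypothetical stand-in for self.unet(...).sample, only to exercise the loop
    return torch.zeros_like(x)

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

# scale the initial noise by the standard deviation required by the scheduler
latents = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    # identity for DDIM/DDPM/PNDM; the K-LMS 1 / (sigma**2 + 1) ** 0.5 scaling for LMS
    latent_model_input = scheduler.scale_model_input(latents, t)
    noise_pred = fake_unet(latent_model_input, t)
    # every scheduler now receives the timestep value t, never the loop index i
    latents = scheduler.step(noise_pred, t, latents).prev_sample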
@@ -226,13 +226,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -310,16 +306,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps[t_start:].to(self.device)

         for i, t in enumerate(self.progress_bar(timesteps)):
-            t_index = t_start + i
-
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-
-            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -330,10 +319,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
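The same two unified lines now select the start timestep for every scheduler, and the `dtype=torch.long` cast is gone because LMS timesteps become floats later in this commit. A hedged sketch of the selection plus the `add_noise` call that consumes it (`strength`, `batch_size`, and the shapes are made-up values):

# --- sketch, not part of the diff ---
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
num_inference_steps = 50
scheduler.set_timesteps(num_inference_steps)

strength, batch_size = 0.8, 2
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)

# works for LMS and the discrete schedulers alike; note: no dtype=torch.long
t = scheduler.timesteps[-init_timestep]
timesteps = torch.tensor([t] * batch_size)

init_latents = torch.randn(batch_size, 4, 8, 8)
noise = torch.randn_like(init_latents)
noisy_latents = scheduler.add_noise(init_latents, noise, timesteps)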
@@ -260,13 +260,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -348,13 +344,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps[t_start:].to(self.device)

         for i, t in tqdm(enumerate(timesteps)):
-            t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -365,14 +357,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            # masking
+            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))

             latents = (init_latents_proper * mask) + (latents * (1 - mask))

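The inpainting loop keeps the unmasked region honest by re-noising the original latents to the current noise level each step and blending with the mask; with float timesteps, `torch.tensor([t])` replaces the old `torch.LongTensor([t_index])`. A toy, runnable illustration of just that masking step (all tensors are stand-ins, and the `scheduler.step` update is elided):

# --- sketch, not part of the diff ---
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=10)

init_latents_orig = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(init_latents_orig)
latents = torch.randn(1, 4, 8, 8) * scheduler.init_noise_sigma

mask = torch.zeros(1, 4, 8, 8)
mask[:, :, :4, :] = 1.0  # mask == 1 marks the region to keep from the original

for t in scheduler.timesteps:
    # ... scheduler.step(...) would update `latents` here ...
    init_latents_proper = scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
    latents = (init_latents_proper * mask) + (latents * (1 - mask))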
@@ -147,9 +147,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -163,10 +161,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(
@@ -180,11 +175,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = np.array(latents)

             # call the callback, if provided
@@ -69,7 +69,7 @@ class KarrasVePipeline(DiffusionPipeline):
         model = self.unet

         # sample x_0 ~ N(0, sigma_0^2 * I)
-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
@@ -152,10 +152,27 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
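DDIM needs no input scaling, but it still gains a `scale_model_input` so callers can invoke it unconditionally; the identical no-op lands in DDPM, PNDM, Karras VE, and SDE-VE below. A quick runnable check of that contract (the assertions mirror the diff, not new behavior):

# --- sketch, not part of the diff ---
import torch
from diffusers import DDIMScheduler, LMSDiscreteScheduler

sample = torch.randn(1, 4, 8, 8)

ddim = DDIMScheduler()
ddim.set_timesteps(10)
assert torch.equal(ddim.scale_model_input(sample, ddim.timesteps[0]), sample)  # identity

lms = LMSDiscreteScheduler()
lms.set_timesteps(10)
scaled = lms.scale_model_input(sample, lms.timesteps[0])
sigma = lms.sigmas[0]
assert torch.allclose(scaled, sample / ((sigma**2 + 1) ** 0.5))  # K-LMS scaling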
@@ -140,12 +140,29 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         self.one = torch.tensor(1.0)

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

         self.variance_type = variance_type

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -95,11 +95,28 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
             take_from=kwargs,
         )

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max
+
         # setable values
         self.num_inference_steps: int = None
         self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union

@@ -102,11 +102,36 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = self.sigmas.max()
+
         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.derivatives = []
+        self.is_scale_input_called = False
+
+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        self.is_scale_input_called = True
+        return sample

     def get_lms_coefficient(self, order, t, current_order):
         """
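Why `(sigma**2 + 1) ** 0.5`: LMS noisy samples are built as x = x0 + sigma * eps (see `add_noise` below), so for roughly unit-variance x0 and eps the input's standard deviation is sqrt(sigma**2 + 1), and dividing by it restores the unit-variance input the model was trained on. A back-of-the-envelope numeric check:

# --- sketch, not part of the diff ---
import torch

sigma = 7.0
x0, eps = torch.randn(1_000_000), torch.randn(1_000_000)
x = x0 + sigma * eps  # same construction as LMS add_noise
print(float(x.std()))                              # ~ (sigma**2 + 1) ** 0.5 ≈ 7.07
print(float((x / ((sigma**2 + 1) ** 0.5)).std()))  # ~ 1.0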
@@ -154,7 +179,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: int,
+        timestep: Union[float, torch.FloatTensor],
         sample: torch.FloatTensor,
         order: int = 4,
         return_dict: bool = True,
@@ -165,7 +190,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):

         Args:
             model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            timestep (`int`): current discrete timestep in the diffusion chain.
+            timestep (`float`): current timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
@@ -177,7 +202,21 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.

         """
-        sigma = self.sigmas[timestep]
+        if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
+            warnings.warn(
+                f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
+                "Make sure to pass one of the `scheduler.timesteps`"
+            )
+        if not self.is_scale_input_called:
+            warnings.warn(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
@@ -189,8 +228,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.derivatives.pop(0)

         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(step_index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
@@ -206,12 +245,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def add_noise(
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         sigmas = self.sigmas.to(original_samples.device)
+        schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

-        sigma = sigmas[timesteps].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
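The recurring pattern in `step` and `add_noise` above is a value-to-index lookup: callers hand over one of `scheduler.timesteps`, and the scheduler recovers its own position in the schedule instead of trusting a loop counter that silently breaks when the caller and the schedule disagree. A small runnable demonstration:

# --- sketch, not part of the diff ---
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(num_inference_steps=5)
print(scheduler.timesteps)  # five float timesteps, descending

t = scheduler.timesteps[2]
step_index = (scheduler.timesteps == t).nonzero().item()
assert step_index == 2                       # the lookup recovers the caller's position
print(float(scheduler.sigmas[step_index]))   # the sigma `step` will use for this t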
@@ -129,6 +129,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0
+
         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -342,6 +345,19 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

         return SchedulerOutput(prev_sample=prev_sample)

+    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)
@@ -84,11 +84,28 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             take_from=kwargs,
         )

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max
+
         # setable values
         self.timesteps = None

         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(
         self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
     ):
@@ -201,7 +201,7 @@ class SchedulerCommonTest(unittest.TestCase):
         )

         kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", 50)

         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config()
@@ -226,6 +226,27 @@ class SchedulerCommonTest(unittest.TestCase):

             recursive_check(outputs_tuple, outputs_dict)

+    def test_scheduler_public_api(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+            self.assertTrue(
+                hasattr(scheduler, "init_noise_sigma"),
+                f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
+            )
+            self.assertTrue(
+                hasattr(scheduler, "scale_model_input"),
+                f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
+            )
+            self.assertTrue(
+                hasattr(scheduler, "step"),
+                f"{scheduler_class} does not implement a required class method `step(...)`",
+            )
+
+            sample = self.dummy_sample
+            scaled_sample = scheduler.scale_model_input(sample, 0.0)
+            self.assertEqual(sample.shape, scaled_sample.shape)
+

 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
@@ -865,14 +886,14 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         scheduler.set_timesteps(self.num_inference_steps)

         model = self.dummy_model()
-        sample = self.dummy_sample_deter * scheduler.sigmas[0]
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma

         for i, t in enumerate(scheduler.timesteps):
-            sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+            sample = scheduler.scale_model_input(sample, t)

             model_output = model(sample, t)

-            output = scheduler.step(model_output, i, sample)
+            output = scheduler.step(model_output, t, sample)
             sample = output.prev_sample

         result_sum = torch.sum(torch.abs(sample))