[Scheduler design] The pragmatic approach (#719)

* init

* improve add_noise

* [debug start] run slow test

* [debug end]

* quick revert

* Add docstrings and warnings + API tests

* Make the warning less spammy
Anton Lozhkov, 2022-10-05 14:41:19 +02:00, committed by GitHub
parent 726aba089d, commit 6b09f370c4
13 changed files with 179 additions and 76 deletions
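
At a glance, the change replaces every scheduler-specific branch in the pipelines with three scheduler-agnostic touch points: an `init_noise_sigma` attribute, a `scale_model_input(sample, timestep)` method, and a `step(...)` call that takes the timestep value itself instead of a loop index. A minimal sketch of the loop shape the pipelines below converge on (illustrative; `unet`, `scheduler`, `text_embeddings`, `shape`, and `generator` are stand-ins, not code from the diff):

```python
import torch

scheduler.set_timesteps(num_inference_steps)

# scale the initial noise by the standard deviation required by the scheduler
latents = torch.randn(shape, generator=generator) * scheduler.init_noise_sigma

for t in scheduler.timesteps:
    # identity for DDIM/DDPM/PNDM; divides by (sigma**2 + 1) ** 0.5 for K-LMS
    model_input = scheduler.scale_model_input(latents, t)
    noise_pred = unet(model_input, t, encoder_hidden_states=text_embeddings).sample
    # every scheduler now accepts the timestep itself, not a loop index
    latents = scheduler.step(noise_pred, t, latents).prev_sample
```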

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

@@ -57,7 +57,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
         model = self.unet

-        sample = torch.randn(*shape, generator=generator) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape, generator=generator) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -281,9 +281,8 @@ class StableDiffusionPipeline(DiffusionPipeline):
         # It's more optimized to move all timesteps to correct device beforehand
         timesteps_tensor = self.scheduler.timesteps.to(self.device)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -297,10 +296,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(timesteps_tensor)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -311,10 +307,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

@@ -226,13 +226,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -310,16 +306,9 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps[t_start:].to(self.device)

         for i, t in enumerate(self.progress_bar(timesteps)):
-            t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

-            # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -330,10 +319,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

             # call the callback, if provided
             if callback is not None and i % callback_steps == 0:
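
To make the timestep window concrete, here is the arithmetic of the selection above with assumed values (not taken from the diff):

```python
# Assumed example values; the formulas are the ones in the hunk above.
num_inference_steps, strength, offset = 50, 0.6, 1

init_timestep = int(num_inference_steps * strength) + offset   # 31
init_timestep = min(init_timestep, num_inference_steps)        # still 31

# noise is added at scheduler.timesteps[-31], i.e. 31 steps from the end
# of the 50-step schedule, and denoising runs over those final steps only
```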

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

@@ -260,13 +260,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         offset = self.scheduler.config.get("steps_offset", 0)
         init_timestep = int(num_inference_steps * strength) + offset
         init_timestep = min(init_timestep, num_inference_steps)
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            timesteps = torch.tensor(
-                [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device
-            )
-        else:
-            timesteps = self.scheduler.timesteps[-init_timestep]
-            timesteps = torch.tensor([timesteps] * batch_size, dtype=torch.long, device=self.device)
+        timesteps = self.scheduler.timesteps[-init_timestep]
+        timesteps = torch.tensor([timesteps] * batch_size, device=self.device)

         # add noise to latents using the timesteps
         noise = torch.randn(init_latents.shape, generator=generator, device=self.device)
@@ -348,13 +344,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps[t_start:].to(self.device)

         for i, t in tqdm(enumerate(timesteps)):
-            t_index = t_start + i
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[t_index]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -365,14 +357,9 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, t_index, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t_index]))
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-                # masking
-                init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.LongTensor([t]))
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            # masking
+            init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))

             latents = (init_latents_proper * mask) + (latents * (1 - mask))
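
The blend on the last line is what confines denoising to the masked region: wherever `mask` is 1 the latents are reset to a freshly re-noised copy of the original image, and wherever it is 0 the denoised proposal is kept. A toy demonstration of the arithmetic (shapes and values are assumptions):

```python
import torch

latents = torch.zeros(1, 4, 8, 8)             # denoised proposal
init_latents_proper = torch.ones(1, 4, 8, 8)  # re-noised original latents
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4] = 1.0                           # 1 = keep original, 0 = keep proposal

blended = (init_latents_proper * mask) + (latents * (1 - mask))
assert torch.equal(blended[..., :4], init_latents_proper[..., :4])
assert torch.equal(blended[..., 4:], latents[..., 4:])
```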

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_onnx.py

@@ -147,9 +147,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
-        if isinstance(self.scheduler, LMSDiscreteScheduler):
-            latents = latents * self.scheduler.sigmas[0]
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -163,10 +161,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
         for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                sigma = self.scheduler.sigmas[i]
-                # the model input needs to be scaled to match the continuous ODE formulation in K-LMS
-                latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
+            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

             # predict the noise residual
             noise_pred = self.unet(
@@ -180,11 +175,7 @@ class StableDiffusionOnnxPipeline(DiffusionPipeline):
                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

             # compute the previous noisy sample x_t -> x_t-1
-            if isinstance(self.scheduler, LMSDiscreteScheduler):
-                latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
-            else:
-                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
             latents = np.array(latents)

             # call the callback, if provided

src/diffusers/pipelines/stochastic_karras_ve/pipeline_stochastic_karras_ve.py

@@ -69,7 +69,7 @@ class KarrasVePipeline(DiffusionPipeline):
         model = self.unet

         # sample x_0 ~ N(0, sigma_0^2 * I)
-        sample = torch.randn(*shape) * self.scheduler.config.sigma_max
+        sample = torch.randn(*shape) * self.scheduler.init_noise_sigma
         sample = sample.to(self.device)

         self.scheduler.set_timesteps(num_inference_steps)
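
Worth noting across these pipeline hunks: `init_noise_sigma` is what lets one multiplication serve both scheduler families. Variance-preserving schedulers (DDIM, DDPM, PNDM) expose 1.0, while variance-exploding ones (Karras VE, score SDE VE, and K-LMS via `sigmas.max()`) expose their maximum sigma. A quick check (illustrative, using the schedulers' default configs):

```python
from diffusers import DDPMScheduler, KarrasVeScheduler

# variance-preserving: the initial noise is already unit-variance
assert DDPMScheduler().init_noise_sigma == 1.0

# variance-exploding: the initial noise must have standard deviation sigma_max
karras = KarrasVeScheduler()
assert karras.init_noise_sigma == karras.config.sigma_max
```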

src/diffusers/schedulers/scheduling_ddim.py

@@ -152,10 +152,27 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
         # whether we use the final alpha of the "non-previous" one.
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0

         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
         alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
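
DDIM needs no input scaling, so `scale_model_input` is a deliberate no-op; its purpose is to let callers drop the `isinstance` checks seen in the pipeline hunks above. The contract, roughly (an illustrative check, not part of the commit):

```python
import torch
from diffusers import DDIMScheduler, LMSDiscreteScheduler

sample = torch.randn(1, 4, 64, 64)

# identity for DDIM ...
ddim = DDIMScheduler()
assert torch.equal(ddim.scale_model_input(sample, 0), sample)

# ... but a real rescale for K-LMS
lms = LMSDiscreteScheduler()
lms.set_timesteps(10)
scaled = lms.scale_model_input(sample, lms.timesteps[0])
assert not torch.equal(scaled, sample)
```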

src/diffusers/schedulers/scheduling_ddpm.py

@@ -140,12 +140,29 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
         self.one = torch.tensor(1.0)

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0

         # setable values
         self.num_inference_steps = None
         self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

         self.variance_type = variance_type

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

src/diffusers/schedulers/scheduling_karras_ve.py

@@ -95,11 +95,28 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
             take_from=kwargs,
         )

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max

         # setable values
         self.num_inference_steps: int = None
         self.timesteps: np.IntTensor = None
         self.schedule: torch.FloatTensor = None  # sigma(t_i)

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the continuous timesteps used for the diffusion chain. Supporting function to be run before inference.

src/diffusers/schedulers/scheduling_lms_discrete.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import warnings

 from dataclasses import dataclass
 from typing import Optional, Tuple, Union
@@ -102,11 +102,36 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas)

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = self.sigmas.max()

         # setable values
         self.num_inference_steps = None
         timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
         self.timesteps = torch.from_numpy(timesteps)
         self.derivatives = []
+        self.is_scale_input_called = False

+    def scale_model_input(
+        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
+    ) -> torch.FloatTensor:
+        """
+        Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the K-LMS algorithm.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]
+        sample = sample / ((sigma**2 + 1) ** 0.5)
+        self.is_scale_input_called = True
+        return sample
+
     def get_lms_coefficient(self, order, t, current_order):
         """
@@ -154,7 +179,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
     def step(
         self,
         model_output: torch.FloatTensor,
-        timestep: int,
+        timestep: Union[float, torch.FloatTensor],
         sample: torch.FloatTensor,
         order: int = 4,
         return_dict: bool = True,
@@ -165,7 +190,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         Args:
             model_output (`torch.FloatTensor`): direct output from learned diffusion model.
-            timestep (`int`): current discrete timestep in the diffusion chain.
+            timestep (`float`): current timestep in the diffusion chain.
             sample (`torch.FloatTensor`):
                 current instance of sample being created by diffusion process.
             order: coefficient for multi-step inference.
@@ -177,7 +202,21 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.

         """
-        sigma = self.sigmas[timestep]
+        if not isinstance(timestep, float) and not isinstance(timestep, torch.FloatTensor):
+            warnings.warn(
+                f"`LMSDiscreteScheduler` timesteps must be `float` or `torch.FloatTensor`, not {type(timestep)}. "
+                "Make sure to pass one of the `scheduler.timesteps`"
+            )
+        if not self.is_scale_input_called:
+            warnings.warn(
+                "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
+                "See `StableDiffusionPipeline` for a usage example."
+            )
+
+        if isinstance(timestep, torch.Tensor):
+            timestep = timestep.to(self.timesteps.device)
+        step_index = (self.timesteps == timestep).nonzero().item()
+        sigma = self.sigmas[step_index]

         # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
         pred_original_sample = sample - sigma * model_output
@@ -189,8 +228,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             self.derivatives.pop(0)

         # 3. Compute linear multistep coefficients
-        order = min(timestep + 1, order)
-        lms_coeffs = [self.get_lms_coefficient(order, timestep, curr_order) for curr_order in range(order)]
+        order = min(step_index + 1, order)
+        lms_coeffs = [self.get_lms_coefficient(order, step_index, curr_order) for curr_order in range(order)]

         # 4. Compute previous sample based on the derivatives path
         prev_sample = sample + sum(
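
Two things happen here: timesteps are now float values that get looked up in `self.timesteps` rather than used directly as indices, and the new warning enforces that `scale_model_input` is paired with `step`. The scaling matters numerically because the pipelines start from noise with standard deviation `init_noise_sigma = sigmas.max()`, and dividing by `(sigma**2 + 1) ** 0.5 ≈ sigma` brings the model input back near unit variance. A quick sanity check (illustrative):

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(20)

# initial noise scaled the way the refactored pipelines do it
sample = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

scaled = scheduler.scale_model_input(sample, scheduler.timesteps[0])
print(sample.std().item())  # roughly init_noise_sigma
print(scaled.std().item())  # roughly 1.0
```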
@@ -206,12 +245,14 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
         self,
         original_samples: torch.FloatTensor,
         noise: torch.FloatTensor,
-        timesteps: torch.IntTensor,
+        timesteps: torch.FloatTensor,
     ) -> torch.FloatTensor:
         sigmas = self.sigmas.to(original_samples.device)
+        schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

-        sigma = sigmas[timesteps].flatten()
+        sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
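
Since `add_noise` now maps timestep values back to sigma indices itself, the img2img and inpaint pipelines can hand it entries of `scheduler.timesteps` directly. A sketch mirroring the pipeline usage above (illustrative shapes and values):

```python
import torch
from diffusers import LMSDiscreteScheduler

scheduler = LMSDiscreteScheduler()
scheduler.set_timesteps(50)

init_latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(init_latents)

# pass actual timestep values from scheduler.timesteps, not integer positions
t = scheduler.timesteps[-31]
noisy = scheduler.add_noise(init_latents, noise, torch.tensor([t] * 2))
```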

src/diffusers/schedulers/scheduling_pndm.py

@@ -129,6 +129,9 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
         self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = 1.0

         # For now we only support F-PNDM, i.e. the runge-kutta method
         # For more information on the algorithm please take a look at the paper: https://arxiv.org/pdf/2202.09778.pdf
         # mainly at formula (9), (12), (13) and the Algorithm 2.
@@ -342,6 +345,19 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

         return SchedulerOutput(prev_sample=prev_sample)

+    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
         # this function computes x_(t−δ) using the formula of (9)

src/diffusers/schedulers/scheduling_sde_ve.py

@@ -84,11 +84,28 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
             take_from=kwargs,
         )

+        # standard deviation of the initial noise distribution
+        self.init_noise_sigma = sigma_max

         # setable values
         self.timesteps = None

         self.set_sigmas(num_train_timesteps, sigma_min, sigma_max, sampling_eps)

+    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+        """
+        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+        current timestep.
+
+        Args:
+            sample (`torch.FloatTensor`): input sample
+            timestep (`int`, optional): current timestep
+
+        Returns:
+            `torch.FloatTensor`: scaled input sample
+        """
+        return sample
+
     def set_timesteps(
         self, num_inference_steps: int, sampling_eps: float = None, device: Union[str, torch.device] = None
     ):

tests/test_scheduler.py

@@ -201,7 +201,7 @@ class SchedulerCommonTest(unittest.TestCase):
             )

         kwargs = dict(self.forward_default_kwargs)
-        num_inference_steps = kwargs.pop("num_inference_steps", None)
+        num_inference_steps = kwargs.pop("num_inference_steps", 50)

         for scheduler_class in self.scheduler_classes:
             scheduler_config = self.get_scheduler_config()
@@ -226,6 +226,27 @@ class SchedulerCommonTest(unittest.TestCase):

             recursive_check(outputs_tuple, outputs_dict)

+    def test_scheduler_public_api(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config)
+
+            self.assertTrue(
+                hasattr(scheduler, "init_noise_sigma"),
+                f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
+            )
+            self.assertTrue(
+                hasattr(scheduler, "scale_model_input"),
+                f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
+            )
+            self.assertTrue(
+                hasattr(scheduler, "step"),
+                f"{scheduler_class} does not implement a required class method `step(...)`",
+            )
+
+            sample = self.dummy_sample
+            scaled_sample = scheduler.scale_model_input(sample, 0.0)
+            self.assertEqual(sample.shape, scaled_sample.shape)
+

 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
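
Read as a contract, the new common test says every scheduler must expose `init_noise_sigma`, `scale_model_input`, and `step`. A hypothetical skeleton that would satisfy the `hasattr` checks above (all names and the trivial update rule are invented for illustration):

```python
import torch
from dataclasses import dataclass

@dataclass
class MySchedulerOutput:
    prev_sample: torch.FloatTensor

class MyScheduler:
    def __init__(self, num_train_timesteps: int = 1000):
        self.init_noise_sigma = 1.0  # std of the initial noise distribution
        self.timesteps = torch.arange(num_train_timesteps - 1, -1, -1)

    def set_timesteps(self, num_inference_steps: int, device=None):
        self.timesteps = torch.linspace(999, 0, num_inference_steps, device=device)

    def scale_model_input(self, sample, timestep=None):
        # identity unless the formulation needs timestep-dependent scaling
        return sample

    def step(self, model_output, timestep, sample, return_dict=True):
        # placeholder update rule, purely for illustration
        return MySchedulerOutput(prev_sample=sample - model_output)
```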
@@ -865,14 +886,14 @@ class LMSDiscreteSchedulerTest(SchedulerCommonTest):
         scheduler.set_timesteps(self.num_inference_steps)

         model = self.dummy_model()
-        sample = self.dummy_sample_deter * scheduler.sigmas[0]
+        sample = self.dummy_sample_deter * scheduler.init_noise_sigma

         for i, t in enumerate(scheduler.timesteps):
-            sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)
+            sample = scheduler.scale_model_input(sample, t)

             model_output = model(sample, t)

-            output = scheduler.step(model_output, i, sample)
+            output = scheduler.step(model_output, t, sample)
             sample = output.prev_sample

         result_sum = torch.sum(torch.abs(sample))