From 7c82a16fc14840429566aec40eb9e65aa57005fd Mon Sep 17 00:00:00 2001 From: Anton Lozhkov Date: Thu, 19 Jan 2023 11:35:55 +0100 Subject: [PATCH] Fix EMA for multi-gpu training in the unconditional example (#1930) * improve EMA * style * one EMA model * quality * fix tests * fix test * Apply suggestions from code review Co-authored-by: Pedro Cuenca * re organise the unconditional script * backwards compatibility * default to init values for some args * fix ort script * issubclass => isinstance * update state_dict * docstr * doc * Apply suggestions from code review Co-authored-by: Pedro Cuenca * use .to if device is passed * deprecate device * make flake happy * fix typo Co-authored-by: patil-suraj Co-authored-by: Pedro Cuenca Co-authored-by: Patrick von Platen --- examples/text_to_image/train_text_to_image.py | 113 +------- .../unconditional_image_generation/README.md | 3 + .../train_unconditional.py | 139 +++++++--- .../train_unconditional_ort.py | 26 +- src/diffusers/training_utils.py | 247 ++++++++++++++---- tests/test_modeling_common.py | 4 +- 6 files changed, 315 insertions(+), 217 deletions(-) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b77f9056..443403ca 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and import argparse -import copy import logging import math import os import random from pathlib import Path -from typing import Iterable, Optional +from typing import Optional import numpy as np import torch @@ -36,6 +35,7 @@ from accelerate.utils import set_seed from datasets import load_dataset from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version from diffusers.utils.import_utils import is_xformers_available from huggingface_hub import HfFolder, Repository, whoami @@ -305,115 +305,6 @@ dataset_name_mapping = { } -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 -class EMAModel: - """ - Exponential Moving Average of models weights - """ - - def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - self.collected_params = None - - self.decay = decay - self.optimization_step = 0 - - @torch.no_grad() - def step(self, parameters): - parameters = list(parameters) - - self.optimization_step += 1 - - # Compute the decay factor for the exponential moving average. - value = (1 + self.optimization_step) / (10 + self.optimization_step) - one_minus_decay = 1 - min(self.decay, value) - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - s_param.sub_(one_minus_decay * (s_param - param)) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] - - def state_dict(self) -> dict: - r""" - Returns the state of the ExponentialMovingAverage as a dict. - This method is used by accelerate during checkpointing to save the ema state dict. - """ - # Following PyTorch conventions, references to tensors are returned: - # "returns a reference to the state and not its copy!" - - # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict - return { - "decay": self.decay, - "optimization_step": self.optimization_step, - "shadow_params": self.shadow_params, - "collected_params": self.collected_params, - } - - def load_state_dict(self, state_dict: dict) -> None: - r""" - Loads the ExponentialMovingAverage state. - This method is used by accelerate during checkpointing to save the ema state dict. - Args: - state_dict (dict): EMA state. Should be an object returned - from a call to :meth:`state_dict`. - """ - # deepcopy, to be consistent with module API - state_dict = copy.deepcopy(state_dict) - - self.decay = state_dict["decay"] - if self.decay < 0.0 or self.decay > 1.0: - raise ValueError("Decay must be between 0 and 1") - - self.optimization_step = state_dict["optimization_step"] - if not isinstance(self.optimization_step, int): - raise ValueError("Invalid optimization_step") - - self.shadow_params = state_dict["shadow_params"] - if not isinstance(self.shadow_params, list): - raise ValueError("shadow_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): - raise ValueError("shadow_params must all be Tensors") - - self.collected_params = state_dict["collected_params"] - if self.collected_params is not None: - if not isinstance(self.collected_params, list): - raise ValueError("collected_params must be a list") - if not all(isinstance(p, torch.Tensor) for p in self.collected_params): - raise ValueError("collected_params must all be Tensors") - if len(self.collected_params) != len(self.shadow_params): - raise ValueError("collected_params and shadow_params must have the same length") - - def main(): args = parse_args() logging_dir = os.path.join(args.output_dir, args.logging_dir) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 010200b5..63616b8a 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -39,6 +39,7 @@ accelerate launch train_unconditional.py \ --train_batch_size=16 \ --num_epochs=100 \ --gradient_accumulation_steps=1 \ + --use_ema \ --learning_rate=1e-4 \ --lr_warmup_steps=500 \ --mixed_precision=no \ @@ -63,6 +64,7 @@ accelerate launch train_unconditional.py \ --train_batch_size=16 \ --num_epochs=100 \ --gradient_accumulation_steps=1 \ + --use_ema \ --learning_rate=1e-4 \ --lr_warmup_steps=500 \ --mixed_precision=no \ @@ -150,6 +152,7 @@ accelerate launch train_unconditional_ort.py \ --dataset_name="huggan/flowers-102-categories" \ --resolution=64 \ --output_dir="ddpm-ema-flowers-64" \ + --use_ema \ --train_batch_size=16 \ --num_epochs=1 \ --gradient_accumulation_steps=1 \ diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 25d4b879..4b92028a 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,5 +1,7 @@ import argparse +import copy import inspect +import logging import math import os from pathlib import Path @@ -8,6 +10,8 @@ from typing import Optional import torch import torch.nn.functional as F +import datasets +import diffusers from accelerate import Accelerator from accelerate.logging import get_logger from datasets import load_dataset @@ -29,10 +33,10 @@ from tqdm.auto import tqdm # Will error if the minimal version of diffusers is not installed. Remove at your own risks. -check_min_version("0.10.0.dev0") +check_min_version("0.12.0.dev0") -logger = get_logger(__name__) +logger = get_logger(__name__, log_level="INFO") def _extract_into_tensor(arr, timesteps, broadcast_shape): @@ -156,7 +160,6 @@ def parse_args(): parser.add_argument( "--use_ema", action="store_true", - default=True, help="Whether to use Exponential Moving Average for the final model weights.", ) parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") @@ -255,6 +258,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: def main(args): logging_dir = os.path.join(args.output_dir, args.logging_dir) + accelerator = Accelerator( gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision, @@ -262,6 +266,38 @@ def main(args): logging_dir=logging_dir, ) + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Initialize the model model = UNet2DModel( sample_size=args.resolution, in_channels=3, @@ -285,8 +321,19 @@ def main(args): "UpBlock2D", ), ) - accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + # Create EMA for the model. + if args.use_ema: + ema_model = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + ) + + # Initialize the scheduler + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) if accepts_prediction_type: noise_scheduler = DDPMScheduler( num_train_timesteps=args.ddpm_num_steps, @@ -296,6 +343,7 @@ def main(args): else: noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule) + # Initialize the optimizer optimizer = torch.optim.AdamW( model.parameters(), lr=args.learning_rate, @@ -304,16 +352,11 @@ def main(args): eps=args.adam_epsilon, ) - augmentations = Compose( - [ - Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), - CenterCrop(args.resolution), - RandomHorizontalFlip(), - ToTensor(), - Normalize([0.5], [0.5]), - ] - ) + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. if args.dataset_name is not None: dataset = load_dataset( args.dataset_name, @@ -323,6 +366,19 @@ def main(args): ) else: dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train") + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + + # Preprocessing the datasets and DataLoaders creation. + augmentations = Compose( + [ + Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), + CenterCrop(args.resolution), + RandomHorizontalFlip(), + ToTensor(), + Normalize([0.5], [0.5]), + ] + ) def transforms(examples): images = [augmentations(image.convert("RGB")) for image in examples["image"]] @@ -335,6 +391,7 @@ def main(args): dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers ) + # Initialize the learning rate scheduler lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, @@ -342,44 +399,37 @@ def main(args): num_training_steps=(len(train_dataloader) * args.num_epochs), ) + # Prepare everything with our `accelerator`. model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - - ema_model = EMAModel( - accelerator.unwrap_model(model), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) - - # Handle the repository creation - if accelerator.is_main_process: - if args.push_to_hub: - if args.hub_model_id is None: - repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) - else: - repo_name = args.hub_model_id - repo = Repository(args.output_dir, clone_from=repo_name) - - with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: - if "step_*" not in gitignore: - gitignore.write("step_*\n") - if "epoch_*" not in gitignore: - gitignore.write("epoch_*\n") - elif args.output_dir is not None: - os.makedirs(args.output_dir, exist_ok=True) + if args.use_ema: + accelerator.register_for_checkpointing(ema_model) + ema_model.to(accelerator.device) + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. if accelerator.is_main_process: run = os.path.split(__file__)[-1].split(".")[0] accelerator.init_trackers(run) + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + max_train_steps = args.num_epochs * num_update_steps_per_epoch + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(dataset)}") + logger.info(f" Num Epochs = {args.num_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + global_step = 0 first_epoch = 0 + # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: if args.resume_from_checkpoint != "latest": path = os.path.basename(args.resume_from_checkpoint) @@ -397,6 +447,7 @@ def main(args): first_epoch = resume_global_step // num_update_steps_per_epoch resume_step = resume_global_step % num_update_steps_per_epoch + # Train! for epoch in range(first_epoch, args.num_epochs): model.train() progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process) @@ -445,12 +496,12 @@ def main(args): accelerator.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() lr_scheduler.step() - if args.use_ema: - ema_model.step(model) optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + if args.use_ema: + ema_model.step(model.parameters()) progress_bar.update(1) global_step += 1 @@ -472,8 +523,11 @@ def main(args): # Generate sample images for visual inspection if accelerator.is_main_process: if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1: + unet = copy.deepcopy(accelerator.unwrap_model(model)) + if args.use_ema: + ema_model.copy_to(unet.parameters()) pipeline = DDPMPipeline( - unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), + unet=unet, scheduler=noise_scheduler, ) @@ -498,7 +552,6 @@ def main(args): pipeline.save_pretrained(args.output_dir) if args.push_to_hub: repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False) - accelerator.wait_for_everyone() accelerator.end_training() diff --git a/examples/unconditional_image_generation/train_unconditional_ort.py b/examples/unconditional_image_generation/train_unconditional_ort.py index df0a4635..44ea5a2d 100644 --- a/examples/unconditional_image_generation/train_unconditional_ort.py +++ b/examples/unconditional_image_generation/train_unconditional_ort.py @@ -157,7 +157,6 @@ def parse_args(): parser.add_argument( "--use_ema", action="store_true", - default=True, help="Whether to use Exponential Moving Average for the final model weights.", ) parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.") @@ -287,8 +286,17 @@ def main(args): "UpBlock2D", ), ) - accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) + if args.use_ema: + ema_model = EMAModel( + model.parameters(), + decay=args.ema_max_decay, + use_ema_warmup=True, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + ) + + accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) if accepts_prediction_type: noise_scheduler = DDPMScheduler( num_train_timesteps=args.ddpm_num_steps, @@ -347,17 +355,13 @@ def main(args): model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler ) - accelerator.register_for_checkpointing(lr_scheduler) + + if args.use_ema: + accelerator.register_for_checkpointing(ema_model) + ema_model.to(accelerator.device) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) - ema_model = EMAModel( - accelerator.unwrap_model(model), - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay, - ) - model = ORTModule(model) # Handle the repository creation @@ -448,7 +452,7 @@ def main(args): optimizer.step() lr_scheduler.step() if args.use_ema: - ema_model.step(model) + ema_model.step(model.parameters()) optimizer.zero_grad() # Checks if the accelerator has performed an optimization step behind the scenes diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index fefc490c..7f43f553 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -1,10 +1,13 @@ import copy import os import random +from typing import Iterable, Union import numpy as np import torch +from .utils import deprecate + def enable_full_determinism(seed: int): """ @@ -39,6 +42,7 @@ def set_seed(seed: int): # ^^ safe to call this function even if cuda is not available +# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 class EMAModel: """ Exponential Moving Average of models weights @@ -46,81 +50,224 @@ class EMAModel: def __init__( self, - model, - update_after_step=0, - inv_gamma=1.0, - power=2 / 3, - min_value=0.0, - max_value=0.9999, - device=None, + parameters: Iterable[torch.nn.Parameter], + decay: float = 0.9999, + min_decay: float = 0.0, + update_after_step: int = 0, + use_ema_warmup: bool = False, + inv_gamma: Union[float, int] = 1.0, + power: Union[float, int] = 2 / 3, + **kwargs, ): """ + Args: + parameters (Iterable[torch.nn.Parameter]): The parameters to track. + decay (float): The decay factor for the exponential moving average. + min_decay (float): The minimum decay factor for the exponential moving average. + update_after_step (int): The number of steps to wait before starting to update the EMA weights. + use_ema_warmup (bool): Whether to use EMA warmup. + inv_gamma (float): + Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. + power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. + device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA + weights will be stored on CPU. + @crowsonkb's notes on EMA Warmup: If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at 215.4k steps). - - Args: - inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. - power (float): Exponential factor of EMA warmup. Default: 2/3. - min_value (float): The minimum EMA decay rate. Default: 0. """ - self.averaged_model = copy.deepcopy(model).eval() - self.averaged_model.requires_grad_(False) + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() + # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility + use_ema_warmup = True + + if kwargs.get("max_value", None) is not None: + deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." + deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False) + decay = kwargs["max_value"] + + if kwargs.get("min_value", None) is not None: + deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead." + deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False) + min_decay = kwargs["min_value"] + + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + + if kwargs.get("device", None) is not None: + deprecation_message = "The `device` argument is deprecated. Please use `to` instead." + deprecate("device", "1.0.0", deprecation_message, standard_warn=False) + self.to(device=kwargs["device"]) + + self.collected_params = None + + self.decay = decay + self.min_decay = min_decay self.update_after_step = update_after_step + self.use_ema_warmup = use_ema_warmup self.inv_gamma = inv_gamma self.power = power - self.min_value = min_value - self.max_value = max_value - - if device is not None: - self.averaged_model = self.averaged_model.to(device=device) - - self.decay = 0.0 self.optimization_step = 0 - def get_decay(self, optimization_step): + def get_decay(self, optimization_step: int) -> float: """ Compute the decay factor for the exponential moving average. """ step = max(0, optimization_step - self.update_after_step - 1) - value = 1 - (1 + step / self.inv_gamma) ** -self.power if step <= 0: return 0.0 - return max(self.min_value, min(value, self.max_value)) + if self.use_ema_warmup: + cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power + else: + cur_decay_value = (1 + step) / (10 + step) + + cur_decay_value = min(cur_decay_value, self.decay) + # make sure decay is not smaller than min_decay + cur_decay_value = max(cur_decay_value, self.min_decay) + return cur_decay_value @torch.no_grad() - def step(self, new_model): - ema_state_dict = {} - ema_params = self.averaged_model.state_dict() + def step(self, parameters: Iterable[torch.nn.Parameter]): + if isinstance(parameters, torch.nn.Module): + deprecation_message = ( + "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. " + "Please pass the parameters of the module instead." + ) + deprecate( + "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`", + "1.0.0", + deprecation_message, + standard_warn=False, + ) + parameters = parameters.parameters() - self.decay = self.get_decay(self.optimization_step) + parameters = list(parameters) - for key, param in new_model.named_parameters(): - if isinstance(param, dict): - continue - try: - ema_param = ema_params[key] - except KeyError: - ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param) - ema_params[key] = ema_param - - if not param.requires_grad: - ema_params[key].copy_(param.to(dtype=ema_param.dtype).data) - ema_param = ema_params[key] - else: - ema_param.mul_(self.decay) - ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay) - - ema_state_dict[key] = ema_param - - for key, param in new_model.named_buffers(): - ema_state_dict[key] = param - - self.averaged_model.load_state_dict(ema_state_dict, strict=False) self.optimization_step += 1 + + # Compute the decay factor for the exponential moving average. + decay = self.get_decay(self.optimization_step) + one_minus_decay = 1 - decay + + for s_param, param in zip(self.shadow_params, parameters): + if param.requires_grad: + s_param.sub_(one_minus_decay * (s_param - param)) + else: + s_param.copy_(param) + + torch.cuda.empty_cache() + + def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the parameters with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = list(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) + for p in self.shadow_params + ] + + def state_dict(self) -> dict: + r""" + Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during + checkpointing to save the ema state dict. + """ + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "min_decay": self.decay, + "optimization_step": self.optimization_step, + "update_after_step": self.update_after_step, + "use_ema_warmup": self.use_ema_warmup, + "inv_gamma": self.inv_gamma, + "power": self.power, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r""" + Args: + Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the + ema state dict. + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + + self.decay = state_dict.get("decay", self.decay) + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.min_decay = state_dict.get("min_decay", self.min_decay) + if not isinstance(self.min_decay, float): + raise ValueError("Invalid min_decay") + + self.optimization_step = state_dict.get("optimization_step", self.optimization_step) + if not isinstance(self.optimization_step, int): + raise ValueError("Invalid optimization_step") + + self.update_after_step = state_dict.get("update_after_step", self.update_after_step) + if not isinstance(self.update_after_step, int): + raise ValueError("Invalid update_after_step") + + self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup) + if not isinstance(self.use_ema_warmup, bool): + raise ValueError("Invalid use_ema_warmup") + + self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma) + if not isinstance(self.inv_gamma, (float, int)): + raise ValueError("Invalid inv_gamma") + + self.power = state_dict["power"].get("power", self.power) + if not isinstance(self.power, (float, int)): + raise ValueError("Invalid power") + + self.shadow_params = state_dict["shadow_params"] + if not isinstance(self.shadow_params, list): + raise ValueError("shadow_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.shadow_params): + raise ValueError("shadow_params must all be Tensors") + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + if not isinstance(self.collected_params, list): + raise ValueError("collected_params must be a list") + if not all(isinstance(p, torch.Tensor) for p in self.collected_params): + raise ValueError("collected_params must all be Tensors") + if len(self.collected_params) != len(self.shadow_params): + raise ValueError("collected_params and shadow_params must have the same length") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 42f683f8..26acdd41 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -205,7 +205,7 @@ class ModelTesterMixin: model = self.model_class(**init_dict) model.to(torch_device) model.train() - ema_model = EMAModel(model, device=torch_device) + ema_model = EMAModel(model.parameters()) output = model(**inputs_dict) @@ -215,7 +215,7 @@ class ModelTesterMixin: noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device) loss = torch.nn.functional.mse_loss(output, noise) loss.backward() - ema_model.step(model) + ema_model.step(model.parameters()) def test_outputs_equivalence(self): def set_nan_tensor_to_zero(t):