Fix EMA for multi-gpu training in the unconditional example (#1930)
* improve EMA
* style
* one EMA model
* quality
* fix tests
* fix test
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* re organise the unconditional script
* backwards compatibility
* default to init values for some args
* fix ort script
* issubclass => isinstance
* update state_dict
* docstr
* doc
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* use .to if device is passed
* deprecate device
* make flake happy
* fix typo

Co-authored-by: patil-suraj <surajp815@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
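A minimal sketch of the parameters-based EMAModel API this commit introduces, based on the constructor and method calls that appear in the diff below. The torch.nn.Linear stand-in, optimizer, and hyperparameter values are placeholders, not taken from the diff.

import torch
from diffusers.training_utils import EMAModel

model = torch.nn.Linear(4, 4)  # stand-in for a UNet; any iterable of parameters works
ema_model = EMAModel(
    model.parameters(),
    decay=0.9999,          # replaces the deprecated `max_value` kwarg
    use_ema_warmup=True,
    inv_gamma=1.0,
    power=3 / 4,
)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
for _ in range(5):
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema_model.step(model.parameters())   # update the shadow parameters

ema_model.copy_to(model.parameters())    # load the averaged weights for eval/saving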
This commit is contained in:
parent f354dd9e2f
commit 7c82a16fc1
@@ -14,13 +14,12 @@
 # See the License for the specific language governing permissions and
 
 import argparse
-import copy
 import logging
 import math
 import os
 import random
 from pathlib import Path
-from typing import Iterable, Optional
+from typing import Optional
 
 import numpy as np
 import torch
@@ -36,6 +35,7 @@ from accelerate.utils import set_seed
 from datasets import load_dataset
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
+from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from diffusers.utils.import_utils import is_xformers_available
 from huggingface_hub import HfFolder, Repository, whoami
@@ -305,115 +305,6 @@ dataset_name_mapping = {
 }
 
 
-# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
-class EMAModel:
-    """
-    Exponential Moving Average of models weights
-    """
-
-    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
-        parameters = list(parameters)
-        self.shadow_params = [p.clone().detach() for p in parameters]
-
-        self.collected_params = None
-
-        self.decay = decay
-        self.optimization_step = 0
-
-    @torch.no_grad()
-    def step(self, parameters):
-        parameters = list(parameters)
-
-        self.optimization_step += 1
-
-        # Compute the decay factor for the exponential moving average.
-        value = (1 + self.optimization_step) / (10 + self.optimization_step)
-        one_minus_decay = 1 - min(self.decay, value)
-
-        for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
-            else:
-                s_param.copy_(param)
-
-        torch.cuda.empty_cache()
-
-    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
-        """
-        Copy current averaged parameters into given collection of parameters.
-
-        Args:
-            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
-                updated with the stored moving averages. If `None`, the
-                parameters with which this `ExponentialMovingAverage` was
-                initialized will be used.
-        """
-        parameters = list(parameters)
-        for s_param, param in zip(self.shadow_params, parameters):
-            param.data.copy_(s_param.data)
-
-    def to(self, device=None, dtype=None) -> None:
-        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
-
-        Args:
-            device: like `device` argument to `torch.Tensor.to`
-        """
-        # .to() on the tensors handles None correctly
-        self.shadow_params = [
-            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
-            for p in self.shadow_params
-        ]
-
-    def state_dict(self) -> dict:
-        r"""
-        Returns the state of the ExponentialMovingAverage as a dict.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        """
-        # Following PyTorch conventions, references to tensors are returned:
-        # "returns a reference to the state and not its copy!" -
-        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
-        return {
-            "decay": self.decay,
-            "optimization_step": self.optimization_step,
-            "shadow_params": self.shadow_params,
-            "collected_params": self.collected_params,
-        }
-
-    def load_state_dict(self, state_dict: dict) -> None:
-        r"""
-        Loads the ExponentialMovingAverage state.
-        This method is used by accelerate during checkpointing to save the ema state dict.
-        Args:
-            state_dict (dict): EMA state. Should be an object returned
-                from a call to :meth:`state_dict`.
-        """
-        # deepcopy, to be consistent with module API
-        state_dict = copy.deepcopy(state_dict)
-
-        self.decay = state_dict["decay"]
-        if self.decay < 0.0 or self.decay > 1.0:
-            raise ValueError("Decay must be between 0 and 1")
-
-        self.optimization_step = state_dict["optimization_step"]
-        if not isinstance(self.optimization_step, int):
-            raise ValueError("Invalid optimization_step")
-
-        self.shadow_params = state_dict["shadow_params"]
-        if not isinstance(self.shadow_params, list):
-            raise ValueError("shadow_params must be a list")
-        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
-            raise ValueError("shadow_params must all be Tensors")
-
-        self.collected_params = state_dict["collected_params"]
-        if self.collected_params is not None:
-            if not isinstance(self.collected_params, list):
-                raise ValueError("collected_params must be a list")
-            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
-                raise ValueError("collected_params must all be Tensors")
-            if len(self.collected_params) != len(self.shadow_params):
-                raise ValueError("collected_params and shadow_params must have the same length")
-
-
 def main():
     args = parse_args()
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
@@ -39,6 +39,7 @@ accelerate launch train_unconditional.py \
   --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
+  --use_ema \
   --learning_rate=1e-4 \
   --lr_warmup_steps=500 \
   --mixed_precision=no \
@@ -63,6 +64,7 @@ accelerate launch train_unconditional.py \
   --train_batch_size=16 \
   --num_epochs=100 \
   --gradient_accumulation_steps=1 \
+  --use_ema \
   --learning_rate=1e-4 \
   --lr_warmup_steps=500 \
   --mixed_precision=no \
@@ -150,6 +152,7 @@ accelerate launch train_unconditional_ort.py \
   --dataset_name="huggan/flowers-102-categories" \
   --resolution=64 \
   --output_dir="ddpm-ema-flowers-64" \
+  --use_ema \
   --train_batch_size=16 \
   --num_epochs=1 \
   --gradient_accumulation_steps=1 \
@@ -1,5 +1,7 @@
 import argparse
+import copy
 import inspect
+import logging
 import math
 import os
 from pathlib import Path
@@ -8,6 +10,8 @@ from typing import Optional
 import torch
 import torch.nn.functional as F
 
+import datasets
+import diffusers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from datasets import load_dataset
@@ -29,10 +33,10 @@ from tqdm.auto import tqdm
 
 
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.10.0.dev0")
+check_min_version("0.12.0.dev0")
 
 
-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level="INFO")
 
 
 def _extract_into_tensor(arr, timesteps, broadcast_shape):
@@ -156,7 +160,6 @@ def parse_args():
     parser.add_argument(
         "--use_ema",
         action="store_true",
-        default=True,
         help="Whether to use Exponential Moving Average for the final model weights.",
     )
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
@@ -255,6 +258,7 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
 
 def main(args):
     logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
     accelerator = Accelerator(
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision,
@@ -262,6 +266,38 @@ def main(args):
         logging_dir=logging_dir,
     )
 
+    # Make one log on every process with the configuration for debugging.
+    logging.basicConfig(
+        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+        datefmt="%m/%d/%Y %H:%M:%S",
+        level=logging.INFO,
+    )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()
+
+    # Handle the repository creation
+    if accelerator.is_main_process:
+        if args.push_to_hub:
+            if args.hub_model_id is None:
+                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
+            else:
+                repo_name = args.hub_model_id
+            repo = Repository(args.output_dir, clone_from=repo_name)
+
+            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+                if "step_*" not in gitignore:
+                    gitignore.write("step_*\n")
+                if "epoch_*" not in gitignore:
+                    gitignore.write("epoch_*\n")
+        elif args.output_dir is not None:
+            os.makedirs(args.output_dir, exist_ok=True)
+
+    # Initialize the model
     model = UNet2DModel(
         sample_size=args.resolution,
         in_channels=3,
@@ -285,8 +321,19 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
 
+    # Create EMA for the model.
+    if args.use_ema:
+        ema_model = EMAModel(
+            model.parameters(),
+            decay=args.ema_max_decay,
+            use_ema_warmup=True,
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+        )
+
+    # Initialize the scheduler
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
     if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
@@ -296,6 +343,7 @@ def main(args):
     else:
         noise_scheduler = DDPMScheduler(num_train_timesteps=args.ddpm_num_steps, beta_schedule=args.ddpm_beta_schedule)
 
+    # Initialize the optimizer
     optimizer = torch.optim.AdamW(
         model.parameters(),
         lr=args.learning_rate,
@@ -304,16 +352,11 @@ def main(args):
         eps=args.adam_epsilon,
     )
 
-    augmentations = Compose(
-        [
-            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
-            CenterCrop(args.resolution),
-            RandomHorizontalFlip(),
-            ToTensor(),
-            Normalize([0.5], [0.5]),
-        ]
-    )
+    # Get the datasets: you can either provide your own training and evaluation files (see below)
+    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 
+    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
+    # download the dataset.
     if args.dataset_name is not None:
         dataset = load_dataset(
             args.dataset_name,
@@ -323,6 +366,19 @@ def main(args):
         )
     else:
         dataset = load_dataset("imagefolder", data_dir=args.train_data_dir, cache_dir=args.cache_dir, split="train")
+        # See more about loading custom images at
+        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+
+    # Preprocessing the datasets and DataLoaders creation.
+    augmentations = Compose(
+        [
+            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
+            CenterCrop(args.resolution),
+            RandomHorizontalFlip(),
+            ToTensor(),
+            Normalize([0.5], [0.5]),
+        ]
+    )
 
     def transforms(examples):
         images = [augmentations(image.convert("RGB")) for image in examples["image"]]
@@ -335,6 +391,7 @@ def main(args):
         dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
     )
 
+    # Initialize the learning rate scheduler
    lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
@@ -342,44 +399,37 @@ def main(args):
         num_training_steps=(len(train_dataloader) * args.num_epochs),
     )
 
+    # Prepare everything with our `accelerator`.
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)
 
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-
-    ema_model = EMAModel(
-        accelerator.unwrap_model(model),
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay,
-    )
+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_model)
+        ema_model.to(accelerator.device)
 
-    # Handle the repository creation
-    if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            repo = Repository(args.output_dir, clone_from=repo_name)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
-            os.makedirs(args.output_dir, exist_ok=True)
-
+    # We need to initialize the trackers we use, and also store our configuration.
+    # The trackers initializes automatically on the main process.
     if accelerator.is_main_process:
         run = os.path.split(__file__)[-1].split(".")[0]
         accelerator.init_trackers(run)
 
+    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    max_train_steps = args.num_epochs * num_update_steps_per_epoch
+
+    logger.info("***** Running training *****")
+    logger.info(f"  Num examples = {len(dataset)}")
+    logger.info(f"  Num Epochs = {args.num_epochs}")
+    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
+    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
+    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f"  Total optimization steps = {max_train_steps}")
+
     global_step = 0
     first_epoch = 0
 
+    # Potentially load in the weights and states from a previous save
     if args.resume_from_checkpoint:
         if args.resume_from_checkpoint != "latest":
             path = os.path.basename(args.resume_from_checkpoint)
@@ -397,6 +447,7 @@ def main(args):
         first_epoch = resume_global_step // num_update_steps_per_epoch
         resume_step = resume_global_step % num_update_steps_per_epoch
 
+    # Train!
     for epoch in range(first_epoch, args.num_epochs):
         model.train()
         progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process)
@@ -445,12 +496,12 @@ def main(args):
                     accelerator.clip_grad_norm_(model.parameters(), 1.0)
                 optimizer.step()
                 lr_scheduler.step()
-                if args.use_ema:
-                    ema_model.step(model)
                 optimizer.zero_grad()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
+                if args.use_ema:
+                    ema_model.step(model.parameters())
                 progress_bar.update(1)
                 global_step += 1
 
@@ -472,8 +523,11 @@ def main(args):
         # Generate sample images for visual inspection
         if accelerator.is_main_process:
             if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
+                unet = copy.deepcopy(accelerator.unwrap_model(model))
+                if args.use_ema:
+                    ema_model.copy_to(unet.parameters())
                 pipeline = DDPMPipeline(
-                    unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
+                    unet=unet,
                     scheduler=noise_scheduler,
                 )
 
@@ -498,7 +552,6 @@ def main(args):
                 pipeline.save_pretrained(args.output_dir)
                 if args.push_to_hub:
                     repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
-        accelerator.wait_for_everyone()
 
     accelerator.end_training()
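Condensed, runnable sketch of the multi-GPU EMA flow the script above now follows: one EMAModel per process, registered with accelerate for checkpointing, moved to the local device, and stepped on the model's parameters only after a real optimizer step. The toy linear model, random data, and hyperparameter values are stand-ins for the UNet and image dataset, not values from the diff.

import copy

import torch
from accelerate import Accelerator
from diffusers.training_utils import EMAModel

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)  # stand-in for UNet2DModel
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
data = torch.utils.data.DataLoader(torch.randn(32, 4), batch_size=8)

ema_model = EMAModel(model.parameters(), decay=0.9999, use_ema_warmup=True, inv_gamma=1.0, power=3 / 4)

model, optimizer, data = accelerator.prepare(model, optimizer, data)
accelerator.register_for_checkpointing(ema_model)  # EMA state is saved/restored with accelerate checkpoints
ema_model.to(accelerator.device)                   # keep the shadow params on the local device

for batch in data:
    loss = model(batch).pow(2).mean()
    accelerator.backward(loss)
    optimizer.step()
    optimizer.zero_grad()
    if accelerator.sync_gradients:                 # update the EMA only after a real optimizer step
        ema_model.step(model.parameters())

# For sampling or saving, copy the averaged weights into a copy of the unwrapped model.
eval_model = copy.deepcopy(accelerator.unwrap_model(model))
ema_model.copy_to(eval_model.parameters())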
@@ -157,7 +157,6 @@ def parse_args():
     parser.add_argument(
         "--use_ema",
         action="store_true",
-        default=True,
         help="Whether to use Exponential Moving Average for the final model weights.",
     )
     parser.add_argument("--ema_inv_gamma", type=float, default=1.0, help="The inverse gamma value for the EMA decay.")
@@ -287,8 +286,17 @@ def main(args):
             "UpBlock2D",
         ),
     )
-    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
 
+    if args.use_ema:
+        ema_model = EMAModel(
+            model.parameters(),
+            decay=args.ema_max_decay,
+            use_ema_warmup=True,
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+        )
+
+    accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
     if accepts_prediction_type:
         noise_scheduler = DDPMScheduler(
             num_train_timesteps=args.ddpm_num_steps,
@@ -347,17 +355,13 @@ def main(args):
     model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         model, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)
+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_model)
+        ema_model.to(accelerator.device)
 
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 
-    ema_model = EMAModel(
-        accelerator.unwrap_model(model),
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay,
-    )
-
     model = ORTModule(model)
 
     # Handle the repository creation
@@ -448,7 +452,7 @@ def main(args):
                 optimizer.step()
                 lr_scheduler.step()
                 if args.use_ema:
-                    ema_model.step(model)
+                    ema_model.step(model.parameters())
                 optimizer.zero_grad()
 
             # Checks if the accelerator has performed an optimization step behind the scenes
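The ORT script keeps the model wrapped in ORTModule but now steps the EMA on model.parameters(). A small sketch of why stepping on parameters is wrapper-agnostic; a plain wrapper module stands in for ORTModule here, which is an assumption for illustration only.

import torch
from diffusers.training_utils import EMAModel


class Wrapper(torch.nn.Module):          # stand-in for ORTModule / DistributedDataParallel
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)


inner = torch.nn.Linear(4, 4)
ema = EMAModel(inner.parameters(), decay=0.999)

wrapped = Wrapper(inner)
# The wrapper exposes the very same parameter tensors as the module it wraps,
# so the EMA shadow params stay aligned with the trained weights.
assert all(a is b for a, b in zip(wrapped.parameters(), inner.parameters()))
ema.step(wrapped.parameters())           # same tensors, same order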
@@ -1,10 +1,13 @@
 import copy
 import os
 import random
+from typing import Iterable, Union
 
 import numpy as np
 import torch
 
+from .utils import deprecate
+
 
 def enable_full_determinism(seed: int):
     """
@@ -39,6 +42,7 @@ def set_seed(seed: int):
     # ^^ safe to call this function even if cuda is not available
 
 
+# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """
     Exponential Moving Average of models weights
@@ -46,81 +50,224 @@ class EMAModel:
 
     def __init__(
         self,
-        model,
-        update_after_step=0,
-        inv_gamma=1.0,
-        power=2 / 3,
-        min_value=0.0,
-        max_value=0.9999,
-        device=None,
+        parameters: Iterable[torch.nn.Parameter],
+        decay: float = 0.9999,
+        min_decay: float = 0.0,
+        update_after_step: int = 0,
+        use_ema_warmup: bool = False,
+        inv_gamma: Union[float, int] = 1.0,
+        power: Union[float, int] = 2 / 3,
+        **kwargs,
     ):
         """
+        Args:
+            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
+            decay (float): The decay factor for the exponential moving average.
+            min_decay (float): The minimum decay factor for the exponential moving average.
+            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
+            use_ema_warmup (bool): Whether to use EMA warmup.
+            inv_gamma (float):
+                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
+            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
+            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
+                weights will be stored on CPU.
+
         @crowsonkb's notes on EMA Warmup:
         If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
         to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
         gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
         at 215.4k steps).
-
-        Args:
-            inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
-            power (float): Exponential factor of EMA warmup. Default: 2/3.
-            min_value (float): The minimum EMA decay rate. Default: 0.
         """
 
-        self.averaged_model = copy.deepcopy(model).eval()
-        self.averaged_model.requires_grad_(False)
+        if isinstance(parameters, torch.nn.Module):
+            deprecation_message = (
+                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
+                "Please pass the parameters of the module instead."
+            )
+            deprecate(
+                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
+                "1.0.0",
+                deprecation_message,
+                standard_warn=False,
+            )
+            parameters = parameters.parameters()
+
+            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
+            use_ema_warmup = True
+
+        if kwargs.get("max_value", None) is not None:
+            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
+            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
+            decay = kwargs["max_value"]
+
+        if kwargs.get("min_value", None) is not None:
+            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
+            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
+            min_decay = kwargs["min_value"]
+
+        parameters = list(parameters)
+        self.shadow_params = [p.clone().detach() for p in parameters]
+
+        if kwargs.get("device", None) is not None:
+            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
+            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
+            self.to(device=kwargs["device"])
+
+        self.collected_params = None
 
+        self.decay = decay
+        self.min_decay = min_decay
         self.update_after_step = update_after_step
+        self.use_ema_warmup = use_ema_warmup
         self.inv_gamma = inv_gamma
         self.power = power
-        self.min_value = min_value
-        self.max_value = max_value
-
-        if device is not None:
-            self.averaged_model = self.averaged_model.to(device=device)
-
-        self.decay = 0.0
         self.optimization_step = 0
 
-    def get_decay(self, optimization_step):
+    def get_decay(self, optimization_step: int) -> float:
         """
         Compute the decay factor for the exponential moving average.
         """
         step = max(0, optimization_step - self.update_after_step - 1)
-        value = 1 - (1 + step / self.inv_gamma) ** -self.power
 
         if step <= 0:
             return 0.0
 
-        return max(self.min_value, min(value, self.max_value))
+        if self.use_ema_warmup:
+            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
+        else:
+            cur_decay_value = (1 + step) / (10 + step)
+
+        cur_decay_value = min(cur_decay_value, self.decay)
+        # make sure decay is not smaller than min_decay
+        cur_decay_value = max(cur_decay_value, self.min_decay)
+        return cur_decay_value
 
     @torch.no_grad()
-    def step(self, new_model):
-        ema_state_dict = {}
-        ema_params = self.averaged_model.state_dict()
+    def step(self, parameters: Iterable[torch.nn.Parameter]):
+        if isinstance(parameters, torch.nn.Module):
+            deprecation_message = (
+                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
+                "Please pass the parameters of the module instead."
+            )
+            deprecate(
+                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
+                "1.0.0",
+                deprecation_message,
+                standard_warn=False,
+            )
+            parameters = parameters.parameters()
 
-        self.decay = self.get_decay(self.optimization_step)
+        parameters = list(parameters)
 
-        for key, param in new_model.named_parameters():
-            if isinstance(param, dict):
-                continue
-            try:
-                ema_param = ema_params[key]
-            except KeyError:
-                ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
-                ema_params[key] = ema_param
-
-            if not param.requires_grad:
-                ema_params[key].copy_(param.to(dtype=ema_param.dtype).data)
-                ema_param = ema_params[key]
-            else:
-                ema_param.mul_(self.decay)
-                ema_param.add_(param.data.to(dtype=ema_param.dtype), alpha=1 - self.decay)
-
-            ema_state_dict[key] = ema_param
-
-        for key, param in new_model.named_buffers():
-            ema_state_dict[key] = param
-
-        self.averaged_model.load_state_dict(ema_state_dict, strict=False)
         self.optimization_step += 1
 
+        # Compute the decay factor for the exponential moving average.
+        decay = self.get_decay(self.optimization_step)
+        one_minus_decay = 1 - decay
+
+        for s_param, param in zip(self.shadow_params, parameters):
+            if param.requires_grad:
+                s_param.sub_(one_minus_decay * (s_param - param))
+            else:
+                s_param.copy_(param)
+
+        torch.cuda.empty_cache()
+
+    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
+        """
+        Copy current averaged parameters into given collection of parameters.
+
+        Args:
+            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                updated with the stored moving averages. If `None`, the parameters with which this
+                `ExponentialMovingAverage` was initialized will be used.
+        """
+        parameters = list(parameters)
+        for s_param, param in zip(self.shadow_params, parameters):
+            param.data.copy_(s_param.data)
+
+    def to(self, device=None, dtype=None) -> None:
+        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+        Args:
+            device: like `device` argument to `torch.Tensor.to`
+        """
+        # .to() on the tensors handles None correctly
+        self.shadow_params = [
+            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
+            for p in self.shadow_params
+        ]
+
+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
+        checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "min_decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "update_after_step": self.update_after_step,
+            "use_ema_warmup": self.use_ema_warmup,
+            "inv_gamma": self.inv_gamma,
+            "power": self.power,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Args:
+        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
+        ema state dict.
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict.get("decay", self.decay)
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.min_decay = state_dict.get("min_decay", self.min_decay)
+        if not isinstance(self.min_decay, float):
+            raise ValueError("Invalid min_decay")
+
+        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
+        if not isinstance(self.update_after_step, int):
+            raise ValueError("Invalid update_after_step")
+
+        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
+        if not isinstance(self.use_ema_warmup, bool):
+            raise ValueError("Invalid use_ema_warmup")
+
+        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
+        if not isinstance(self.inv_gamma, (float, int)):
+            raise ValueError("Invalid inv_gamma")
+
+        self.power = state_dict["power"].get("power", self.power)
+        if not isinstance(self.power, (float, int)):
+            raise ValueError("Invalid power")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
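For reference, a standalone helper that mirrors the decay schedule implemented in get_decay above (illustrative only; the real method additionally offsets the step by update_after_step and reads its settings from the instance):

def ema_decay(step, decay=0.9999, min_decay=0.0, inv_gamma=1.0, power=3 / 4, use_ema_warmup=True):
    # Mirrors EMAModel.get_decay: warmup schedule if enabled, else the (1+step)/(10+step) default,
    # clipped into [min_decay, decay].
    if step <= 0:
        return 0.0
    if use_ema_warmup:
        value = 1 - (1 + step / inv_gamma) ** -power
    else:
        value = (1 + step) / (10 + step)
    return max(min_decay, min(value, decay))

for s in (10, 100, 1_000, 10_000, 215_400):
    print(s, round(ema_decay(s), 4))
# With gamma=1 and power=3/4 the decay reaches ~0.999 around 10K steps and ~0.9999 around
# 215.4K steps, matching @crowsonkb's notes quoted in the docstring above.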
@@ -205,7 +205,7 @@ class ModelTesterMixin:
         model = self.model_class(**init_dict)
         model.to(torch_device)
         model.train()
-        ema_model = EMAModel(model, device=torch_device)
+        ema_model = EMAModel(model.parameters())
 
         output = model(**inputs_dict)
 
@@ -215,7 +215,7 @@ class ModelTesterMixin:
         noise = torch.randn((inputs_dict["sample"].shape[0],) + self.output_shape).to(torch_device)
         loss = torch.nn.functional.mse_loss(output, noise)
         loss.backward()
-        ema_model.step(model)
+        ema_model.step(model.parameters())
 
     def test_outputs_equivalence(self):
         def set_nan_tensor_to_zero(t):