Synchronize ranks for DDP
This commit is contained in:
parent 4572617ff9
commit 7102d313ac
@@ -78,7 +78,7 @@ parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default='True', he
 parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
 parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
 parser.add_argument('--fp16', dest='fp16', type=bool_t, default='False', help='Train in mixed precision')
-parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
+parser.add_argument('--image_log_steps', type=int, default=500, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
@@ -691,6 +691,13 @@ def main():
     unet = unet.to(device, dtype=torch.float32)
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
 
+    unet = torch.nn.parallel.DistributedDataParallel(
+        unet,
+        device_ids=[rank],
+        output_device=rank,
+        gradient_as_bucket_view=True
+    )
+
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
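Note: DistributedDataParallel with device_ids=[rank] assumes each worker process has already joined a process group and pinned its own GPU before this point. A minimal sketch of that setup under torchrun, with placeholder names (this script's actual launcher and environment handling may differ):

    import os
    import torch
    import torch.distributed as dist

    def setup_distributed():
        # torchrun exports RANK / WORLD_SIZE / LOCAL_RANK for every worker process
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group("nccl", rank=rank, world_size=world_size)  # NCCL backend for multi-GPU
        torch.cuda.set_device(local_rank)  # pin this process to its GPU
        return rank, local_rank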
@@ -701,6 +708,7 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
+    """
     optimizer = optimizer_cls(
         unet.parameters(),
         lr=args.lr,
@@ -708,13 +716,25 @@ def main():
         eps=args.adam_epsilon,
         weight_decay=args.adam_weight_decay,
     )
+    """
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule='scaled_linear',
-        num_train_timesteps=1000,
-        clip_sample=False
+    # Create distributed optimizer
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+    optimizer = ZeroRedundancyOptimizer(
+        unet.parameters(),
+        optimizer_class=optimizer_cls,
+        parameters_as_bucket_view=True,
+        lr=args.lr,
+        betas=(args.adam_beta1, args.adam_beta2),
+        eps=args.adam_epsilon,
+        weight_decay=args.adam_weight_decay,
+    )
+
+
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        args.model,
+        subfolder='scheduler',
+        use_auth_token=args.hf_token,
     )
 
     # load dataset
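Note: ZeroRedundancyOptimizer shards the wrapped optimizer's state across ranks instead of replicating it on every GPU, which is why it takes optimizer_class plus the usual AdamW keyword arguments. A standalone sketch of the same call shape, using a toy model and illustrative hyperparameters rather than this script's values (it assumes the process group is already initialized):

    import torch
    from torch.distributed.optim import ZeroRedundancyOptimizer

    model = torch.nn.Linear(128, 128).cuda()  # stand-in for the DDP-wrapped unet
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,   # each rank keeps only its shard of AdamW state
        parameters_as_bucket_view=True,      # pack parameters into buckets to reduce broadcast overhead
        lr=1e-4,
        weight_decay=1e-2,
    )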
@@ -743,8 +763,6 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
-
     # create ema
     if args.use_ema:
         ema_unet = EMAModel(unet.parameters())
@@ -786,8 +804,6 @@ def main():
 
         if args.use_ema:
             ema_unet.restore(unet.parameters())
-        # barrier
-        torch.distributed.barrier()
 
     # train!
     try:
@@ -823,10 +839,6 @@ def main():
                 else:
                     encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
-                # Predict the noise residual and compute loss
-                with torch.autocast('cuda', enabled=args.fp16):
-                    noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
-
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
@@ -834,15 +846,20 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
-                # backprop and update
-                scaler.scale(loss).backward()
-                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
-                scaler.step(optimizer)
-                scaler.update()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+                with unet.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
 
                 # Update EMA
                 if args.use_ema:
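Note: unet.join() is DistributedDataParallel's context manager for uneven inputs; ranks that run out of work keep shadowing the collective calls so the remaining ranks do not hang. A hedged sketch of the general pattern, with placeholder names (ddp_model, loader, optimizer):

    # assumes ddp_model is a DistributedDataParallel module and loader may
    # yield a different number of batches on each rank
    with ddp_model.join():
        for batch in loader:
            loss = ddp_model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

The documented pattern wraps the whole loop in a single join() rather than entering it once per step, and gradient clipping under GradScaler usually calls scaler.unscale_(optimizer) before clip_grad_norm_ so the threshold applies to unscaled gradients.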
@@ -875,11 +892,10 @@ def main():
                     progress_bar.set_postfix(logs)
                     run.log(logs, step=global_step)
 
-                    if global_step % args.save_steps == 0:
+                    if global_step % args.save_steps == 0 and global_step > 0:
                         save_checkpoint(global_step)
 
                     if args.enableinference:
-                        if global_step % args.image_log_steps == 0:
+                        if global_step % args.image_log_steps == 0 and global_step > 0:
                             if rank == 0:
                                 # get prompt from random batch
                                 prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
@@ -935,7 +951,6 @@ def main():
                                 # cleanup so we don't run out of memory
                                 del pipeline
                                 gc.collect()
-            torch.distributed.barrier()
     except Exception as e:
         print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
         pass
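Note: the torch.distributed.barrier() calls touched in this diff are the usual way to keep ranks aligned around work that only rank 0 performs, such as checkpointing or image logging: the other ranks wait at the barrier until rank 0 finishes. A minimal sketch of that pattern, with hypothetical helper names:

    import torch.distributed as dist

    def maybe_log_samples(rank, build_pipeline):
        if rank == 0:
            pipeline = build_pipeline()      # hypothetical factory for the inference pipeline
            pipeline("a sample prompt")      # rank 0 does the logging work
            del pipeline                     # free memory before training resumes
        dist.barrier()                       # all ranks sync up before the next step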