Synchronize ranks for DDP

parent 4572617ff9
commit 7102d313ac
@@ -78,7 +78,7 @@ parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default='True', he
 parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
 parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
 parser.add_argument('--fp16', dest='fp16', type=bool_t, default='False', help='Train in mixed precision')
-parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
+parser.add_argument('--image_log_steps', type=int, default=500, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
@@ -691,6 +691,13 @@ def main():
     unet = unet.to(device, dtype=torch.float32)
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
 
+    unet = torch.nn.parallel.DistributedDataParallel(
+        unet,
+        device_ids=[rank],
+        output_device=rank,
+        gradient_as_bucket_view=True
+    )
+
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
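Note: this hunk moves the DistributedDataParallel wrap up so that it happens before the optimizer is built (the old wrap further down in the file is removed in a later hunk). The ordering matters because the optimizer, and in particular the ZeroRedundancyOptimizer introduced below, should be constructed over the parameters of the module that will actually be trained. A minimal standalone sketch of the pattern, to be launched with torchrun; the Linear model here is an illustrative stand-in for the unet, not code from this repo:

    import os
    import torch
    import torch.distributed as dist

    # Illustrative setup; the trainer in this repo does its own init elsewhere.
    dist.init_process_group(backend='nccl')
    rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(16, 16).cuda(rank)  # stand-in for the unet

    # Wrap with DDP first...
    model = torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[rank],
        output_device=rank,
        gradient_as_bucket_view=True,  # grads alias bucket memory, saving one copy
    )

    # ...then build the optimizer over the wrapped module's parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)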
@@ -701,6 +708,7 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
+    """
     optimizer = optimizer_cls(
         unet.parameters(),
         lr=args.lr,
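Note: rather than deleting the old single-process optimizer construction, this change fences it off inside a bare triple-quoted string (the closing quotes arrive in the next hunk). A bare string literal is a no-op expression statement in Python, so the block is effectively commented out.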
@@ -708,13 +716,25 @@ def main():
         eps=args.adam_epsilon,
         weight_decay=args.adam_weight_decay,
     )
+    """
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule='scaled_linear',
-        num_train_timesteps=1000,
-        clip_sample=False
+    # Create distributed optimizer
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+    optimizer = ZeroRedundancyOptimizer(
+        unet.parameters(),
+        optimizer_class=optimizer_cls,
+        parameters_as_bucket_view=True,
+        lr=args.lr,
+        betas=(args.adam_beta1, args.adam_beta2),
+        eps=args.adam_epsilon,
+        weight_decay=args.adam_weight_decay,
+    )
+
+
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        args.model,
+        subfolder='scheduler',
+        use_auth_token=args.hf_token,
     )
 
     # load dataset
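Note: ZeroRedundancyOptimizer shards optimizer state across the ranks in the style of ZeRO stage 1, so each process holds only its slice of the AdamW moments instead of a full copy; that is where the memory saving of this change comes from. A minimal sketch of the API as used above, assuming a process group is already initialized; the Linear model is an illustrative stand-in:

    import torch
    from torch.distributed.optim import ZeroRedundancyOptimizer

    # Assumes torch.distributed.init_process_group() has already run
    # (e.g. under torchrun).
    model = torch.nn.Linear(16, 16).cuda()

    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,  # each rank builds AdamW over its shard
        parameters_as_bucket_view=True,     # pack shards into contiguous buckets
        lr=1e-4,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
    )

    # Before checkpointing the full optimizer state, gather the shards
    # onto one rank:
    optimizer.consolidate_state_dict(to=0)

One consequence worth noting: optimizer.state_dict() on a ZeroRedundancyOptimizer is only complete on the rank that consolidate_state_dict() gathered to, so checkpointing code has to account for that.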
@@ -743,8 +763,6 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
-
     # create ema
     if args.use_ema:
         ema_unet = EMAModel(unet.parameters())
@@ -786,8 +804,6 @@ def main():
 
         if args.use_ema:
             ema_unet.restore(unet.parameters())
-        # barrier
-        torch.distributed.barrier()
 
     # train!
     try:
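Note: this hunk and the one at the bottom of the file drop torch.distributed.barrier() calls that sit on conditionally executed paths. Collectives such as barrier() must be entered by every rank in the process group; if only some ranks reach one (for example inside a rank-0-only checkpoint or logging branch), the job deadlocks, which is a plausible reading of the "Synchronize ranks" intent here. A toy illustration of the hazard, assuming an initialized process group; save_rank0_state is a hypothetical stand-in:

    import torch.distributed as dist

    def save_rank0_state():
        pass  # hypothetical stand-in for rank-0-only work

    rank = dist.get_rank()

    # Deadlock: only rank 0 enters this barrier, so it waits forever
    # for ranks that skipped the whole branch.
    if rank == 0:
        save_rank0_state()
        dist.barrier()

    # Safe: every rank reaches the same collective in the same order.
    if rank == 0:
        save_rank0_state()
    dist.barrier()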
@@ -823,10 +839,6 @@ def main():
                 else:
                     encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
-                # Predict the noise residual and compute loss
-                with torch.autocast('cuda', enabled=args.fp16):
-                    noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
-
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
@@ -834,15 +846,20 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                with unet.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
-                # backprop and update
-                scaler.scale(loss).backward()
-                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
-                scaler.step(optimizer)
-                scaler.update()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
 
                 # Update EMA
                 if args.use_ema:
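Note: two things change in this hunk. First, the forward call switches from unet.module(...) to unet(...): invoking the inner .module directly bypasses DistributedDataParallel's forward-pass hooks, so the gradient all-reduce can be silently skipped and the ranks drift apart. Second, the step is wrapped in unet.join(), DDP's context manager for uneven inputs, under which ranks that have exhausted their data shadow the collectives of ranks still working. Per the PyTorch docs, though, the documented pattern wraps the whole iteration loop rather than a single step, roughly as sketched below; model is assumed to be DDP-wrapped, and per_rank_loader is a hypothetical loader that may yield a different number of batches on each rank:

    # Documented Join usage: the context manager surrounds the loop, so a
    # rank that runs out of batches keeps answering collectives until all
    # ranks are done.
    with model.join():
        for batch in per_rank_loader:
            loss = model(batch).sum()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()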
@@ -875,11 +892,10 @@ def main():
                 progress_bar.set_postfix(logs)
                 run.log(logs, step=global_step)
 
-                if global_step % args.save_steps == 0:
+                if global_step % args.save_steps == 0 and global_step > 0:
                     save_checkpoint(global_step)
 
                 if args.enableinference:
-                    if global_step % args.image_log_steps == 0:
+                    if global_step % args.image_log_steps == 0 and global_step > 0:
                         if rank == 0:
                             # get prompt from random batch
                             prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
@@ -935,7 +951,6 @@ def main():
                             # cleanup so we don't run out of memory
                             del pipeline
                             gc.collect()
-                        torch.distributed.barrier()
     except Exception as e:
         print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
         pass