Fix distributed training

Anthony Mercurio 2022-11-29 18:01:17 -07:00 committed by GitHub
parent 5f0a952eff
commit 8decb0bc7d
1 changed file with 16 additions and 15 deletions


@@ -777,9 +777,7 @@ def main():
         vae=vae,
         unet=unet,
         tokenizer=tokenizer,
-        scheduler=PNDMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-        ),
+        scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
         safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
         feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
     )
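
The point of this hunk: instead of hard-coding the Stable Diffusion v1 beta schedule, the scheduler config is read from the checkpoint itself, so the beta settings and `prediction_type` follow whatever `args.model` points at. A minimal sketch of what the new call does (the repo id below is only an example, not the script's):

    # Minimal sketch, assuming a standard diffusers checkpoint layout with a
    # "scheduler" subfolder; the repo id is an example.
    from diffusers import PNDMScheduler

    scheduler = PNDMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
    )
    # The config carries the values that were previously hard-coded:
    print(scheduler.config.beta_start, scheduler.config.beta_end)  # 0.00085 0.012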
@@ -828,10 +826,20 @@ def main():
             with torch.autocast('cuda', enabled=args.fp16):
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                if noise_scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif noise_scheduler.config.prediction_type == "v_prediction":
+                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
-            # Backprop and all reduce
+            # All-reduce loss, backprop, and update weights
+            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+            loss = loss / world_size
             scaler.scale(loss).backward()
+            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
             scaler.step(optimizer)
             scaler.update()
             lr_scheduler.step()
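
For v-prediction checkpoints (e.g. Stable Diffusion 2.x) the regression target is the velocity rather than the raw noise, which is what `noise_scheduler.get_velocity` supplies. A standalone sketch of the formula it implements, assuming access to the scheduler's `alphas_cumprod` table (a re-derivation for illustration, not the script's code):

    import torch

    def velocity_target(latents, noise, timesteps, alphas_cumprod):
        # v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
        sqrt_alpha = alphas_cumprod[timesteps] ** 0.5
        sqrt_one_minus_alpha = (1.0 - alphas_cumprod[timesteps]) ** 0.5
        while sqrt_alpha.dim() < latents.dim():  # broadcast over (B, C, H, W)
            sqrt_alpha = sqrt_alpha.unsqueeze(-1)
            sqrt_one_minus_alpha = sqrt_one_minus_alpha.unsqueeze(-1)
        return sqrt_alpha * noise - sqrt_one_minus_alpha * latents

The all-reduce sums each rank's mean loss and divides by `world_size`, so every rank holds the global mean before `backward()`. Since `all_reduce` is not tracked by autograd, gradient synchronization itself is presumably handled by a DDP wrapper around the UNet; the reduced value mainly feeds the logging below.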
@@ -849,14 +857,11 @@ def main():
             world_images_per_second = rank_images_per_second * world_size
             samples_seen = global_step * args.batch_size * world_size
 
-            # All reduce loss
-            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-
             if rank == 0:
                 progress_bar.update(1)
                 global_step += 1
                 logs = {
-                    "train/loss": loss.detach().item() / world_size,
+                    "train/loss": loss.detach().item(),
                     "train/lr": lr_scheduler.get_last_lr()[0],
                     "train/epoch": epoch,
                     "train/step": global_step,
@@ -878,14 +883,10 @@ def main():
                 if args.image_log_scheduler == 'DDIMScheduler':
                     print('using DDIMScheduler scheduler')
-                    scheduler = DDIMScheduler(
-                        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-                    )
+                    scheduler = DDIMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
                 else:
                     print('using PNDMScheduler scheduler')
-                    scheduler=PNDMScheduler(
-                        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-                    )
+                    scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
 
                 pipeline = StableDiffusionPipeline(
                     text_encoder=text_encoder,
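
The image-logging path gets the same treatment: both scheduler classes are instantiated from the checkpoint's scheduler config instead of hard-coded betas, so sample images use the same schedule the model was trained with. A hypothetical usage sketch of the assembled pipeline (prompt and output path invented for illustration):

    import torch

    with torch.no_grad():
        image = pipeline("a red fox in watercolor", num_inference_steps=30).images[0]
    image.save(f"samples/step_{global_step}.png")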