Fix distributed training

This commit is contained in:
Anthony Mercurio 2022-11-29 18:01:17 -07:00 committed by GitHub
parent 5f0a952eff
commit 8decb0bc7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 16 additions and 15 deletions

View File

@@ -777,9 +777,7 @@ def main():
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
@@ -828,10 +826,20 @@ def main():
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
# Backprop and all reduce
# All-reduce loss, backprop, and update weights
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
loss = loss / world_size
scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
lr_scheduler.step()
@@ -849,14 +857,11 @@ def main():
world_images_per_second = rank_images_per_second * world_size
samples_seen = global_step * args.batch_size * world_size
# All reduce loss
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
if rank == 0:
progress_bar.update(1)
global_step += 1
logs = {
"train/loss": loss.detach().item() / world_size,
"train/loss": loss.detach().item(),
"train/lr": lr_scheduler.get_last_lr()[0],
"train/epoch": epoch,
"train/step": global_step,
@@ -878,14 +883,10 @@ def main():
if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler')
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
scheduler = DDIMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
else:
print('using PNDMScheduler scheduler')
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,