Fix distributed training
This commit is contained in:
parent
5f0a952eff
commit
8decb0bc7d
|
@ -777,9 +777,7 @@ def main():
|
|||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
),
|
||||
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
|
@ -828,10 +826,20 @@ def main():
|
|||
with torch.autocast('cuda', enabled=args.fp16):
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
# Backprop and all reduce
|
||||
# All-reduce loss, backprop, and update weights
|
||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||
loss = loss / world_size
|
||||
scaler.scale(loss).backward()
|
||||
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
|
||||
scaler.step(optimizer)
|
||||
scaler.update()
|
||||
lr_scheduler.step()
|
||||
|
@ -849,14 +857,11 @@ def main():
|
|||
world_images_per_second = rank_images_per_second * world_size
|
||||
samples_seen = global_step * args.batch_size * world_size
|
||||
|
||||
# All reduce loss
|
||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
logs = {
|
||||
"train/loss": loss.detach().item() / world_size,
|
||||
"train/loss": loss.detach().item(),
|
||||
"train/lr": lr_scheduler.get_last_lr()[0],
|
||||
"train/epoch": epoch,
|
||||
"train/step": global_step,
|
||||
|
@ -878,14 +883,10 @@ def main():
|
|||
|
||||
if args.image_log_scheduler == 'DDIMScheduler':
|
||||
print('using DDIMScheduler scheduler')
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
scheduler = DDIMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
|
||||
else:
|
||||
print('using PNDMScheduler scheduler')
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
)
|
||||
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
|
|
Loading…
Reference in New Issue