Fix distributed training
This commit is contained in:
parent
5f0a952eff
commit
8decb0bc7d
|
@ -777,9 +777,7 @@ def main():
|
||||||
vae=vae,
|
vae=vae,
|
||||||
unet=unet,
|
unet=unet,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
scheduler=PNDMScheduler(
|
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
|
||||||
),
|
|
||||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||||
)
|
)
|
||||||
|
@ -828,10 +826,20 @@ def main():
|
||||||
with torch.autocast('cuda', enabled=args.fp16):
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
target = noise
|
||||||
|
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||||
|
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
|
||||||
|
|
||||||
|
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
# Backprop and all reduce
|
# All-reduce loss, backprop, and update weights
|
||||||
|
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
||||||
|
loss = loss / world_size
|
||||||
scaler.scale(loss).backward()
|
scaler.scale(loss).backward()
|
||||||
|
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
|
||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
|
@ -849,14 +857,11 @@ def main():
|
||||||
world_images_per_second = rank_images_per_second * world_size
|
world_images_per_second = rank_images_per_second * world_size
|
||||||
samples_seen = global_step * args.batch_size * world_size
|
samples_seen = global_step * args.batch_size * world_size
|
||||||
|
|
||||||
# All reduce loss
|
|
||||||
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
|
|
||||||
|
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
progress_bar.update(1)
|
progress_bar.update(1)
|
||||||
global_step += 1
|
global_step += 1
|
||||||
logs = {
|
logs = {
|
||||||
"train/loss": loss.detach().item() / world_size,
|
"train/loss": loss.detach().item(),
|
||||||
"train/lr": lr_scheduler.get_last_lr()[0],
|
"train/lr": lr_scheduler.get_last_lr()[0],
|
||||||
"train/epoch": epoch,
|
"train/epoch": epoch,
|
||||||
"train/step": global_step,
|
"train/step": global_step,
|
||||||
|
@ -878,14 +883,10 @@ def main():
|
||||||
|
|
||||||
if args.image_log_scheduler == 'DDIMScheduler':
|
if args.image_log_scheduler == 'DDIMScheduler':
|
||||||
print('using DDIMScheduler scheduler')
|
print('using DDIMScheduler scheduler')
|
||||||
scheduler = DDIMScheduler(
|
scheduler = DDIMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
print('using PNDMScheduler scheduler')
|
print('using PNDMScheduler scheduler')
|
||||||
scheduler=PNDMScheduler(
|
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
|
||||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = StableDiffusionPipeline(
|
pipeline = StableDiffusionPipeline(
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
|
|
Loading…
Reference in New Issue