use ddp for everything

This commit is contained in:
Anthony Mercurio 2022-11-30 10:54:30 -07:00 committed by GitHub
parent b0cec788be
commit 4572617ff9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 5 deletions

View File

@ -743,8 +743,7 @@ def main():
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
exit(0) exit(0)
dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True) unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
unet = dist_unet.module
# create ema # create ema
if args.use_ema: if args.use_ema:
@ -776,7 +775,7 @@ def main():
pipeline = StableDiffusionPipeline( pipeline = StableDiffusionPipeline(
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
unet=unet, unet=unet.module,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token), scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
@ -801,6 +800,7 @@ def main():
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
continue continue
b_start = time.perf_counter() b_start = time.perf_counter()
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample() latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215 latents = latents * 0.18215
@ -825,7 +825,7 @@ def main():
# Predict the noise residual and compute loss # Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16): with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
if noise_scheduler.config.prediction_type == "epsilon": if noise_scheduler.config.prediction_type == "epsilon":
target = noise target = noise
@ -894,7 +894,7 @@ def main():
pipeline = StableDiffusionPipeline( pipeline = StableDiffusionPipeline(
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
unet=unet, unet=unet.module,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=scheduler, scheduler=scheduler,
safety_checker=None, # disable safety checker to save memory safety_checker=None, # disable safety checker to save memory