use ddp for everything
This commit is contained in:
parent
b0cec788be
commit
4572617ff9
|
@ -743,8 +743,7 @@ def main():
|
|||
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
|
||||
exit(0)
|
||||
|
||||
dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True)
|
||||
unet = dist_unet.module
|
||||
unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
||||
|
||||
# create ema
|
||||
if args.use_ema:
|
||||
|
@ -776,7 +775,7 @@ def main():
|
|||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
unet=unet.module,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
|
||||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
|
@ -801,6 +800,7 @@ def main():
|
|||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
continue
|
||||
|
||||
b_start = time.perf_counter()
|
||||
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
@ -825,7 +825,7 @@ def main():
|
|||
|
||||
# Predict the noise residual and compute loss
|
||||
with torch.autocast('cuda', enabled=args.fp16):
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
|
@ -894,7 +894,7 @@ def main():
|
|||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
unet=unet.module,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # disable safety checker to save memory
|
||||
|
|
Loading…
Reference in New Issue