use ddp for everything
parent b0cec788be
commit 4572617ff9
@@ -743,8 +743,7 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True)
-    unet = dist_unet.module
+    unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
 
     # create ema
     if args.use_ema:
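For orientation, a minimal sketch of the process-group setup this wrapping assumes: each rank initializes a process group before the model is wrapped, and the wrapper (not the bare module) is what the training loop should see. The helper name setup_distributed is an assumption, not code from this repo.

import os
import torch
import torch.distributed as dist

def setup_distributed() -> int:
    # hypothetical helper: torchrun exports RANK/LOCAL_RANK/WORLD_SIZE
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl")
    return rank

# usage mirroring the wrapping above (unet assumed already built):
# rank = setup_distributed()
# unet = torch.nn.parallel.DistributedDataParallel(
#     unet.to(rank), device_ids=[rank], output_device=rank,
#     gradient_as_bucket_view=True,
# )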
@@ -776,7 +775,7 @@ def main():
     pipeline = StableDiffusionPipeline(
         text_encoder=text_encoder,
         vae=vae,
-        unet=unet,
+        unet=unet.module,
         tokenizer=tokenizer,
         scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token),
         safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
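StableDiffusionPipeline expects the bare UNet, so now that unet is the DDP wrapper it is peeled off with .module before the pipeline is assembled. A small helper (hypothetical, not in this commit) keeps such save paths working whether or not DDP is active:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    # return the underlying module when DDP-wrapped, else the model itself
    return model.module if isinstance(model, DDP) else model

# e.g. unet=unwrap(unet) instead of unet=unet.module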
@@ -801,6 +800,7 @@ def main():
             progress_bar.update(1)
             global_step += 1
             continue
+
         b_start = time.perf_counter()
         latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
         latents = latents * 0.18215
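The 0.18215 multiplier is Stable Diffusion's latent scaling factor, chosen so the VAE latents have roughly unit variance before diffusion. As an alternative to the hard-coded constant, newer diffusers releases expose it on the VAE config; a hedged equivalent of the two lines above:

latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor  # 0.18215 for the SD v1 VAE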
@@ -825,7 +825,7 @@ def main():
 
         # Predict the noise residual and compute loss
         with torch.autocast('cuda', enabled=args.fp16):
-            noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+            noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
 
         if noise_scheduler.config.prediction_type == "epsilon":
             target = noise
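One caveat worth flagging: PyTorch only arms DDP's cross-rank gradient-averaging for a step when the forward pass goes through the wrapper itself, so unet.module(...) runs the bare module with no synchronization. If synchronized gradients are the goal, the usual pattern is to call the wrapper directly:

# forward through the DDP wrapper so gradients are averaged across ranks
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample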
@@ -894,7 +894,7 @@ def main():
         pipeline = StableDiffusionPipeline(
             text_encoder=text_encoder,
             vae=vae,
-            unet=unet,
+            unet=unet.module,
             tokenizer=tokenizer,
             scheduler=scheduler,
             safety_checker=None, # disable safety checker to save memory
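A DDP trainer is typically launched with torchrun (one process per GPU) and shards its data with a DistributedSampler so each rank sees a distinct slice. Neither appears in this diff, so the sketch below is an assumption about the surrounding setup; train_dataset and args.batch_size are placeholders:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# hypothetical: shard the dataset across ranks
sampler = DistributedSampler(train_dataset)
loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=sampler)
# remember to call sampler.set_epoch(epoch) at the start of every epoch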