fp32 Update

This commit is contained in:
cafeai 2022-12-03 19:42:50 +09:00
parent 18ff256be5
commit 3cefb57fc6
1 changed files with 1 additions and 1 deletions

View File

@ -751,7 +751,7 @@ def main():
# move models to device
vae = vae.to(device, dtype=weight_dtype)
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=weight_dtype)
text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32)
unet = torch.nn.parallel.DistributedDataParallel(
unet,