fp32 Update
This commit is contained in:
parent
18ff256be5
commit
3cefb57fc6
|
@ -751,7 +751,7 @@ def main():
|
|||
# move models to device
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32)
|
||||
|
||||
unet = torch.nn.parallel.DistributedDataParallel(
|
||||
unet,
|
||||
|
|
Loading…
Reference in New Issue