fp32 Update
This commit is contained in:
parent
18ff256be5
commit
3cefb57fc6
|
@ -751,7 +751,7 @@ def main():
|
||||||
# move models to device
|
# move models to device
|
||||||
vae = vae.to(device, dtype=weight_dtype)
|
vae = vae.to(device, dtype=weight_dtype)
|
||||||
unet = unet.to(device, dtype=torch.float32)
|
unet = unet.to(device, dtype=torch.float32)
|
||||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32)
|
||||||
|
|
||||||
unet = torch.nn.parallel.DistributedDataParallel(
|
unet = torch.nn.parallel.DistributedDataParallel(
|
||||||
unet,
|
unet,
|
||||||
|
|
Loading…
Reference in New Issue