Update Save Checkpoint
This commit is contained in:
parent
34715bcc97
commit
bf264d0ff0
|
@ -866,7 +866,7 @@ def main():
|
|||
ema_unet.store(unet.parameters())
|
||||
ema_unet.copy_to(unet.parameters())
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
|
||||
vae=vae,
|
||||
unet=unet.module,
|
||||
tokenizer=tokenizer,
|
||||
|
|
Loading…
Reference in New Issue