[LoRA] Freezing the model weights (#2245)
* [LoRA] Freezing the model weights

Freeze the model weights since we don't need to calculate grads for them.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
commit 1be7df0205 (parent 62a15cec6e)
@@ -415,6 +415,11 @@ def main():
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
     )
+    # freeze parameters of models to save more memory
+    unet.requires_grad_(False)
+    vae.requires_grad_(False)
+
+    text_encoder.requires_grad_(False)
 
     # For mixed precision training we cast the text_encoder and vae weights to half-precision
     # as these models are only used for inference, keeping weights in full precision is not required.
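For context, below is a minimal sketch of the pattern this diff feeds into: the base models are loaded, frozen with requires_grad_(False), and the inference-only models are kept in half precision. The model id, device, and dtype choices are illustrative placeholders, not part of this commit; only the requires_grad_(False) calls correspond to the added lines.

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

# Placeholder model id for illustration; the training script passes
# args.pretrained_model_name_or_path instead.
model_id = "runwayml/stable-diffusion-v1-5"

text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

# Freeze the base weights: no gradients are computed or stored for them,
# so only the separately added LoRA parameters end up being trained.
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

# The frozen text encoder and VAE are only run in inference mode during
# LoRA training, so they can be kept in half precision to save memory
# (assumption: a CUDA device is available for fp16 execution).
weight_dtype = torch.float16
device = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(device, dtype=weight_dtype)
text_encoder.to(device, dtype=weight_dtype)

# Sanity check: none of the frozen parameters will receive gradients.
assert not any(p.requires_grad for p in unet.parameters())
assert not any(p.requires_grad for p in vae.parameters())
assert not any(p.requires_grad for p in text_encoder.parameters())

Freezing the base weights removes both their gradient computation and their optimizer state, which is what the added comment means by saving memory during LoRA fine-tuning.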