[LoRA] Freezing the model weights (#2245)

* [LoRA] Freezing the model weights

Freeze the model weights since we don't need to calculate grads for them.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
erkams 2023-02-09 11:45:11 +01:00 committed by GitHub
parent 62a15cec6e
commit 1be7df0205
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -415,7 +415,12 @@ def main():
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32