correct example

Patrick von Platen 2023-03-08 20:14:19 +01:00
parent 00132de359
commit cbbad0af69
1 changed file with 0 additions and 3 deletions

@@ -479,7 +479,6 @@ def main():
        weight_dtype = torch.bfloat16

    if args.use_peft:
        from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict

        UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -496,7 +495,6 @@ def main():
        vae.requires_grad_(False)
        if args.train_text_encoder:
            config = LoraConfig(
                r=args.lora_text_encoder_r,
                lora_alpha=args.lora_text_encoder_alpha,
@@ -506,7 +504,6 @@ def main():
            )
            text_encoder = LoraModel(config, text_encoder)
    else:
        # freeze parameters of models to save more memory
        unet.requires_grad_(False)
        vae.requires_grad_(False)
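
For context, the PEFT pattern this diff touches works as follows: a peft LoraConfig describes the adapter rank, scaling factor, and which submodules to inject low-rank matrices into, and LoraModel wraps the base model so that only those injected matrices remain trainable. Below is a minimal sketch of that pattern; the TinyEncoder module and the r/lora_alpha values are placeholders standing in for the example's text encoder and its args.lora_text_encoder_* hyperparameters, and the LoraModel(config, model) argument order matches the peft release this example targeted (recent peft releases go through get_peft_model(model, config) instead).

import torch
from peft import LoraConfig, LoraModel

# Toy stand-in for the text encoder wrapped in the example above.
class TinyEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(32, 32)
        self.v_proj = torch.nn.Linear(32, 32)

    def forward(self, x):
        return self.q_proj(x) + self.v_proj(x)

# Placeholder hyperparameters (the example reads these from
# args.lora_text_encoder_r and args.lora_text_encoder_alpha).
config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],  # attention projections, mirroring the example's target-module lists
    lora_dropout=0.0,
)

# Wrapping marks only the injected LoRA matrices as trainable; the base
# weights stay frozen. On recent peft releases, prefer:
#     from peft import get_peft_model; model = get_peft_model(TinyEncoder(), config)
model = LoraModel(config, TinyEncoder())
print([n for n, p in model.named_parameters() if p.requires_grad])  # only lora_* parameters

This is also why the else: branch above freezes the unet and vae explicitly: when peft is not used, no wrapper freezes the base weights for you.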