correct example
parent 00132de359
commit cbbad0af69
@@ -479,7 +479,6 @@ def main():
        weight_dtype = torch.bfloat16

    if args.use_peft:
        from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict

        UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
@@ -496,7 +495,6 @@ def main():
        vae.requires_grad_(False)
        if args.train_text_encoder:
            config = LoraConfig(
                r=args.lora_text_encoder_r,
                lora_alpha=args.lora_text_encoder_alpha,
@@ -506,7 +504,6 @@ def main():
            )
            text_encoder = LoraModel(config, text_encoder)
    else:
        # freeze parameters of models to save more memory
        unet.requires_grad_(False)
        vae.requires_grad_(False)
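For orientation, these hunks sit in the PEFT branch of the training script: when args.use_peft is set, the text encoder is wrapped with a LoraConfig/LoraModel pair so that only the LoRA matrices train; otherwise the UNet and VAE are simply frozen. Below is a minimal sketch of that wrapping, assuming the older peft API shown in the hunk (LoraModel constructed as LoraModel(config, model)); the model id, hyperparameter values, and target module names are placeholders standing in for the script's args, not values taken from this commit.

from peft import LoraConfig, LoraModel
from transformers import CLIPTextModel

# Placeholder model id; the script reads this from args.pretrained_model_name_or_path.
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder"
)

# Values stand in for args.lora_text_encoder_r / args.lora_text_encoder_alpha;
# the target modules are an assumed choice for a CLIP text encoder.
config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0,
)

# Older peft constructor order, matching the hunk. Recent peft releases instead use:
#   from peft import get_peft_model; text_encoder = get_peft_model(text_encoder, config)
text_encoder = LoraModel(config, text_encoder)
# After wrapping, only the injected LoRA parameters require gradients.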