diff --git a/examples/research_projects/lora/train_text_to_image_lora.py b/examples/research_projects/lora/train_text_to_image_lora.py index a1ea73f0..2d0f807b 100644 --- a/examples/research_projects/lora/train_text_to_image_lora.py +++ b/examples/research_projects/lora/train_text_to_image_lora.py @@ -479,7 +479,6 @@ def main(): weight_dtype = torch.bfloat16 if args.use_peft: - from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] @@ -496,7 +495,6 @@ def main(): vae.requires_grad_(False) if args.train_text_encoder: - config = LoraConfig( r=args.lora_text_encoder_r, lora_alpha=args.lora_text_encoder_alpha, @@ -506,7 +504,6 @@ def main(): ) text_encoder = LoraModel(config, text_encoder) else: - # freeze parameters of models to save more memory unet.requires_grad_(False) vae.requires_grad_(False)