diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index b3f6a2d..444941a 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -66,6 +66,8 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
 parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
+# type=bool would turn any non-empty string (even 'False') into True, so parse the text explicitly; unrecognised values mean False.
+parser.add_argument('--clip_penultimate', type=lambda v: v.strip().lower() in ('1', 'true', 't', 'yes', 'y', 'on'), default=False, help='Use penultimate CLIP layer for text embedding')
 args = parser.parse_args()
 
 def setup():
@@ -602,7 +604,12 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
+                encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
+                # With --clip_penultimate, condition on the second-to-last hidden layer (re-applying the final layer norm) instead of the last one.
+                if args.clip_penultimate:
+                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
+                else:
+                    encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
                 # Predict the noise residual and compute loss
                 with torch.autocast('cuda', enabled=args.fp16):