Add CLIP penultimate layer for diffusers trainer

This commit is contained in:
Anthony Mercurio 2022-11-02 16:20:46 -07:00 committed by GitHub
parent 1a94c70736
commit 6b73cfe448
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 1 deletions

View File

@ -66,6 +66,7 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision') parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.') parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps') parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
args = parser.parse_args() args = parser.parse_args()
def setup(): def setup():
@ -602,7 +603,11 @@ def main():
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning # Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0] encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
if args.clip_penultimate:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
# Predict the noise residual and compute loss # Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16): with torch.autocast('cuda', enabled=args.fp16):