Add CLIP penultimate layer for diffusers trainer
parent 1a94c70736
commit 6b73cfe448
@@ -66,6 +66,7 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
 parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
+parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
 args = parser.parse_args()
 
 def setup():
@@ -602,7 +603,11 @@ def main():
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
             # Get the text embedding for conditioning
-            encoder_hidden_states = text_encoder(batch["input_ids"].to(device))[0]
+            encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
+            if args.clip_penultimate:
+                encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
+            else:
+                encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
             # Predict the noise residual and compute loss
             with torch.autocast('cuda', enabled=args.fp16):
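Below is a minimal standalone sketch of what the penultimate-layer lookup does, assuming a stock Hugging Face transformers CLIPTextModel; the checkpoint name and prompt are illustrative placeholders, not taken from the trainer. With output_hidden_states=True, hidden_states[-1] is the last encoder layer's output before the final LayerNorm (last_hidden_state is that same tensor after the norm), so the commit takes hidden_states[-2] and runs it through the same final_layer_norm.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

# Illustrative checkpoint; the trainer loads whichever text encoder its
# pipeline ships with.
model_name = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name)

tokens = tokenizer(['a photo of a cat'], padding='max_length',
                   truncation=True, return_tensors='pt')

with torch.no_grad():
    out = text_encoder(tokens.input_ids, output_hidden_states=True)

# Default conditioning: last layer output after the final LayerNorm.
last = out.last_hidden_state

# Penultimate conditioning, as in this commit: the second-to-last hidden
# state, passed through the same final LayerNorm.
penultimate = text_encoder.text_model.final_layer_norm(out.hidden_states[-2])

print(last.shape, penultimate.shape)  # both torch.Size([1, 77, 768]) for ViT-L/14

One caveat on the new flag: argparse's type=bool converts any non-empty string to True, so '--clip_penultimate False' still enables the option; an action='store_true' flag is the usual safer pattern.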