This commit is contained in:
cafeai 2022-12-01 04:32:10 +09:00
parent ee281badcd
commit 981c6ca41a
1 changed files with 1 additions and 6 deletions

View File

@ -867,13 +867,8 @@ def main():
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
#encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
# Get the embedding for conditioning
encoder_hidden_states = batch['input_ids']
#if args.clip_penultimate:
# encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
#else:
# encoder_hidden_states = encoder_hidden_states.last_hidden_state
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):