Cleanup
This commit is contained in:
parent
ee281badcd
commit
981c6ca41a
|
@ -867,13 +867,8 @@ def main():
|
||||||
# (this is the forward diffusion process)
|
# (this is the forward diffusion process)
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
# Get the text embedding for conditioning
|
# Get the embedding for conditioning
|
||||||
#encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
|
|
||||||
encoder_hidden_states = batch['input_ids']
|
encoder_hidden_states = batch['input_ids']
|
||||||
#if args.clip_penultimate:
|
|
||||||
# encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
|
|
||||||
#else:
|
|
||||||
# encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
|
||||||
|
|
||||||
# Predict the noise residual and compute loss
|
# Predict the noise residual and compute loss
|
||||||
with torch.autocast('cuda', enabled=args.fp16):
|
with torch.autocast('cuda', enabled=args.fp16):
|
||||||
|
|
Loading…
Reference in New Issue