[UnCLIPPipeline] fix num_images_per_prompt (#1762)

duplicate maks for num_images_per_prompt
This commit is contained in:
Suraj Patil 2022-12-19 14:32:46 +01:00 committed by GitHub
parent 32a5d70c42
commit be38b2d711
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -143,6 +143,7 @@ class UnCLIPPipeline(DiffusionPipeline):
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size
@ -172,6 +173,7 @@ class UnCLIPPipeline(DiffusionPipeline):
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
uncond_text_mask = uncond_text_mask.repeat(1, num_images_per_prompt)
# done duplicates