[UnCLIPPipeline] fix num_images_per_prompt (#1762)
duplicate maks for num_images_per_prompt
This commit is contained in:
parent
32a5d70c42
commit
be38b2d711
|
@ -143,6 +143,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
|||
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_mask = text_mask.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""] * batch_size
|
||||
|
@ -172,6 +173,7 @@ class UnCLIPPipeline(DiffusionPipeline):
|
|||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
uncond_text_mask = uncond_text_mask.repeat(1, num_images_per_prompt)
|
||||
|
||||
# done duplicates
|
||||
|
||||
|
|
Loading…
Reference in New Issue