[Versatile] fix attention mask (#1763)
parent c7b4acfb37
commit b267d28566
```diff
@@ -959,6 +959,7 @@ class DualTransformer2DModel(nn.Module):
 
         encoded_states = []
         tokens_start = 0
+        # attention_mask is not used yet
         for i in range(2):
             # for each of the two transformers, pass the corresponding condition tokens
             condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
@@ -967,7 +968,6 @@ class DualTransformer2DModel(nn.Module):
                 input_states,
                 encoder_hidden_states=condition_state,
                 timestep=timestep,
-                attention_mask=attention_mask,
                 return_dict=False,
             )[0]
             encoded_states.append(encoded_state - input_states)
```