[Versatile] fix attention mask (#1763)
This commit is contained in:
parent
c7b4acfb37
commit
b267d28566
|
@ -959,6 +959,7 @@ class DualTransformer2DModel(nn.Module):
|
||||||
|
|
||||||
encoded_states = []
|
encoded_states = []
|
||||||
tokens_start = 0
|
tokens_start = 0
|
||||||
|
# attention_mask is not used yet
|
||||||
for i in range(2):
|
for i in range(2):
|
||||||
# for each of the two transformers, pass the corresponding condition tokens
|
# for each of the two transformers, pass the corresponding condition tokens
|
||||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||||
|
@ -967,7 +968,6 @@ class DualTransformer2DModel(nn.Module):
|
||||||
input_states,
|
input_states,
|
||||||
encoder_hidden_states=condition_state,
|
encoder_hidden_states=condition_state,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
attention_mask=attention_mask,
|
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
encoded_states.append(encoded_state - input_states)
|
encoded_states.append(encoded_state - input_states)
|
||||||
|
|
Loading…
Reference in New Issue