[Versatile] fix attention mask (#1763)

This commit is contained in:
Patrick von Platen 2022-12-19 15:58:39 +01:00 committed by GitHub
parent c7b4acfb37
commit b267d28566
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 1 addition and 1 deletion

View File

@@ -959,6 +959,7 @@ class DualTransformer2DModel(nn.Module):
encoded_states = []
tokens_start = 0
# attention_mask is not used yet
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
@@ -967,7 +968,6 @@ class DualTransformer2DModel(nn.Module):
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
attention_mask=attention_mask,
return_dict=False,
)[0]
encoded_states.append(encoded_state - input_states)