diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index b29fe02c..f9e43e4d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -621,6 +621,12 @@ class CrossAttention(nn.Module):
             key = self.reshape_heads_to_batch_dim(key)
             value = self.reshape_heads_to_batch_dim(value)
 
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
@@ -630,7 +636,7 @@ class CrossAttention(nn.Module):
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value, attention_mask)
             else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
@@ -653,11 +659,6 @@ class CrossAttention(nn.Module):
         )
 
         if attention_mask is not None:
-            if attention_mask.shape != attention_scores.shape:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
             attention_scores = attention_scores + attention_mask
 
         if self.upcast_softmax:
@@ -675,7 +676,7 @@ class CrossAttention(nn.Module):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
@@ -699,6 +700,13 @@ class CrossAttention(nn.Module):
                 beta=0,
                 alpha=self.scale,
             )
+
+            if attention_mask is not None:
+                attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+            if self.upcast_softmax:
+                attn_slice = attn_slice.float()
+
             attn_slice = attn_slice.softmax(dim=-1)
 
             # cast back to the original dtype
@@ -716,7 +724,7 @@ class CrossAttention(nn.Module):
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index 1ee6f655..a3a9d39b 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -564,23 +564,6 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
index 0bf6cfd5..c83a347f 100644
--- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
+++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -1250,23 +1250,6 @@ class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
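
Editorial note, not part of the patch: the standalone sketch below mirrors the mask preparation that the first attention.py hunk hoists out of _attention and into forward, so the same padded, per-head mask reaches _attention, _sliced_attention, and the xformers path. The tensor sizes, the -10000.0 fill value, and the variable names are illustrative assumptions. The padding by the query length appears to line up with the added-KV ("simple") attention blocks, where the keys are the encoder tokens concatenated with the spatial tokens, so the encoder mask is zero-extended over the spatial positions.

# Standalone sketch of the hoisted mask preparation (assumed shapes and values).
import torch
import torch.nn.functional as F

heads = 8                              # assumed CrossAttention.heads
batch, query_len, key_len = 2, 64, 77  # assumed spatial and encoder lengths

# Additive mask over encoder tokens: 0.0 keeps a token, a large negative value drops it.
attention_mask = torch.zeros(batch, key_len)
attention_mask[:, 70:] = -10000.0

# Same logic as the lines added to forward(): pad the last dim by the query length,
# then repeat the mask once per attention head so its leading dim matches the
# (batch * heads) leading dim produced by reshape_heads_to_batch_dim.
if attention_mask is not None:
    if attention_mask.shape[-1] != query_len:
        target_length = query_len
        attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
        attention_mask = attention_mask.repeat_interleave(heads, dim=0)

print(attention_mask.shape)  # torch.Size([16, 141]) -> (batch * heads, key_len + query_len)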