Add attention mask to unclip (#1756)
* Remove bogus file
* [Unclip] Add efficient attention
* [Unclip] Add efficient attention
parent dc7cd893fd
commit 429e5449c1
@@ -621,6 +621,12 @@ class CrossAttention(nn.Module):
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
 
+        if attention_mask is not None:
+            if attention_mask.shape[-1] != query.shape[1]:
+                target_length = query.shape[1]
+                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
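For illustration only (not part of the diff): a minimal standalone sketch of the two tensor ops the new mask-preparation block relies on, F.pad extending the mask along its last dimension and repeat_interleave duplicating it once per attention head so it lines up with the head-flattened batch layout. The shapes below are assumptions chosen for the example.

import torch
import torch.nn.functional as F

heads = 8
batch, query_len, key_len = 2, 77, 64

# Assumed mask layout: (batch, key_len), 0.0 where attention is allowed.
attention_mask = torch.zeros(batch, key_len)

# Pad the last dimension on the right with `query_len` extra zeros,
# mirroring F.pad(attention_mask, (0, target_length), value=0.0) above.
padded = F.pad(attention_mask, (0, query_len), value=0.0)
print(padded.shape)  # torch.Size([2, 141])

# Duplicate the mask once per head along dim 0 so it matches the
# (batch * heads, ...) layout that reshape_heads_to_batch_dim produces.
per_head = padded.repeat_interleave(heads, dim=0)
print(per_head.shape)  # torch.Size([16, 141])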
@@ -630,7 +636,7 @@ class CrossAttention(nn.Module):
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value, attention_mask)
             else:
-                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
@@ -653,11 +659,6 @@ class CrossAttention(nn.Module):
         )
 
         if attention_mask is not None:
-            if attention_mask.shape != attention_scores.shape:
-                target_length = query.shape[1]
-                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
-
             attention_scores = attention_scores + attention_mask
 
         if self.upcast_softmax:
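As context for the attention_scores + attention_mask line kept above, a generic sketch of additive attention masking, assuming a mask that holds 0.0 at visible key positions and a large negative value at masked ones so the softmax pushes their weight toward zero. This is a standalone illustration, not the code from the diff.

import torch

# Toy attention logits for one head: (query_len, key_len).
scores = torch.randn(4, 6)

# Additive mask: 0.0 keeps a key position, a large negative value hides it.
mask = torch.zeros(4, 6)
mask[:, 4:] = -10000.0  # assume the last two key positions are padding

probs = (scores + mask).softmax(dim=-1)
print(probs[:, 4:].max())  # ~0: masked positions receive (almost) no weight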
@@ -675,7 +676,7 @@ class CrossAttention(nn.Module):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
-    def _sliced_attention(self, query, key, value, sequence_length, dim):
+    def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
         batch_size_attention = query.shape[0]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
@@ -699,6 +700,13 @@ class CrossAttention(nn.Module):
                 beta=0,
                 alpha=self.scale,
             )
 
+            if attention_mask is not None:
+                attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+            if self.upcast_softmax:
+                attn_slice = attn_slice.float()
+
             attn_slice = attn_slice.softmax(dim=-1)
 
+            # cast back to the original dtype
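The sliced path walks the flattened batch*heads dimension in chunks, so the mask has to be indexed with the same start_idx:end_idx window, which is what the added line above does. A rough standalone sketch of that indexing pattern, with made-up sizes (not the diff's code):

import torch

batch_heads, query_len, key_len = 16, 32, 32
slice_size = 4

scores = torch.randn(batch_heads, query_len, key_len)
attention_mask = torch.zeros(batch_heads, 1, key_len)  # assumed broadcastable layout
out = torch.empty_like(scores)

for i in range(batch_heads // slice_size):
    start_idx, end_idx = i * slice_size, (i + 1) * slice_size
    attn_slice = scores[start_idx:end_idx]
    # Apply only the rows of the mask that belong to this slice of batch*heads.
    attn_slice = attn_slice + attention_mask[start_idx:end_idx]
    out[start_idx:end_idx] = attn_slice.softmax(dim=-1)

print(out.shape)  # torch.Size([16, 32, 32])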
@@ -716,7 +724,7 @@ class CrossAttention(nn.Module):
         query = query.contiguous()
         key = key.contiguous()
         value = value.contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
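The xformers change above forwards the prepared mask as attn_bias, which xformers.ops.memory_efficient_attention adds to the attention logits before its internal softmax, mirroring the additive masking of the standard path. A guarded sketch, assuming the optional xformers package and a CUDA device are available (shapes and dtypes are made up for the example):

import torch

try:
    import xformers.ops as xops
except ImportError:  # xformers is an optional dependency
    xops = None

if xops is not None and torch.cuda.is_available():
    batch_heads, seq_len, head_dim = 16, 64, 40
    q = torch.randn(batch_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k, v = torch.randn_like(q), torch.randn_like(q)

    # Additive bias: 0.0 keeps a key position, large negative values would drop it.
    attn_bias = torch.zeros(batch_heads, seq_len, seq_len, device="cuda", dtype=q.dtype)

    out = xops.memory_efficient_attention(q, k, v, attn_bias=attn_bias)
    print(out.shape)  # (16, 64, 40)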
@@ -564,23 +564,6 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -1250,23 +1250,6 @@ class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
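For reference, the override removed from both mid blocks enforced two constraints on slice_size. A standalone sketch of that validation logic, extracted and simplified from the deleted method (the head configuration is made up):

def check_slice_size(slice_size, head_dims):
    # Mirrors the checks in the deleted set_attention_slice override.
    if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
        raise ValueError(
            f"Make sure slice_size {slice_size} is a common divisor of "
            f"the number of heads used in cross_attention: {head_dims}"
        )
    if slice_size is not None and slice_size > min(head_dims):
        raise ValueError(
            f"slice_size {slice_size} has to be smaller or equal to "
            f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
        )


check_slice_size(2, [8, 16])    # passes: 2 divides both head counts and 2 <= 8
# check_slice_size(3, [8, 16])  # would raise: 3 does not divide 8 or 16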