diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index e9454a46..0547bb4a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -290,11 +290,19 @@ class AttentionBlock(nn.Module):
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
-    def transpose_for_scores(self, projection: torch.Tensor) -> torch.Tensor:
-        new_projection_shape = projection.size()[:-1] + (self.num_heads, -1)
-        # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
-        new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
-        return new_projection
+    def reshape_heads_to_batch_dim(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        return tensor
+
+    def reshape_batch_dim_to_heads(self, tensor):
+        batch_size, seq_len, dim = tensor.shape
+        head_size = self.num_heads
+        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+        return tensor
 
     def forward(self, hidden_states):
         residual = hidden_states
@@ -312,50 +320,28 @@ class AttentionBlock(nn.Module):
 
         scale = 1 / math.sqrt(self.channels / self.num_heads)
 
-        # get scores
-        if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
+        query_proj = self.reshape_heads_to_batch_dim(query_proj)
+        key_proj = self.reshape_heads_to_batch_dim(key_proj)
+        value_proj = self.reshape_heads_to_batch_dim(value_proj)
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_proj.shape[0],
+                query_proj.shape[1],
+                key_proj.shape[1],
+                dtype=query_proj.dtype,
+                device=query_proj.device,
+            ),
+            query_proj,
+            key_proj.transpose(-1, -2),
+            beta=0,
+            alpha=scale,
+        )
 
         attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+        hidden_states = torch.bmm(attention_probs, value_proj)
 
-        # compute attention output
-        if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
-            new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
-            hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
+        # reshape hidden_states
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
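For reviewers who want to sanity-check the refactor locally: the sketch below is not part of the patch; shapes, variable names, and the tolerance are illustrative, and the two reshape helpers are written as free functions taking num_heads instead of methods reading self.num_heads. It asserts that the folded 3D baddbmm/bmm path this diff introduces matches a 4D reference path in the style of the removed transpose_for_scores branch.

import math

import torch


def reshape_heads_to_batch_dim(tensor, num_heads):
    # (B, T, H * D) -> (B * H, T, D), mirroring the helper added in the patch
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size, seq_len, num_heads, dim // num_heads)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * num_heads, seq_len, dim // num_heads)


def reshape_batch_dim_to_heads(tensor, num_heads):
    # (B * H, T, D) -> (B, T, H * D), the inverse of the helper above
    batch_size, seq_len, dim = tensor.shape
    tensor = tensor.reshape(batch_size // num_heads, num_heads, seq_len, dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size // num_heads, seq_len, dim * num_heads)


batch, seq_len, channels, num_heads = 2, 16, 64, 4
scale = 1 / math.sqrt(channels / num_heads)
q, k, v = (torch.randn(batch, seq_len, channels) for _ in range(3))

# 3D path used by the patch: fold heads into the batch dim, one baddbmm for the scaled scores
q3, k3, v3 = (reshape_heads_to_batch_dim(t, num_heads) for t in (q, k, v))
scores = torch.baddbmm(
    torch.empty(q3.shape[0], q3.shape[1], k3.shape[1], dtype=q3.dtype, device=q3.device),
    q3,
    k3.transpose(-1, -2),
    beta=0,  # beta=0 ignores the uninitialized `input` tensor, so torch.empty is safe here
    alpha=scale,
)
out = reshape_batch_dim_to_heads(torch.bmm(scores.softmax(dim=-1), v3), num_heads)

# Reference 4D path, equivalent to the removed transpose_for_scores branch
q4, k4, v4 = (t.reshape(batch, seq_len, num_heads, -1).permute(0, 2, 1, 3) for t in (q, k, v))
ref = (torch.matmul(q4, k4.transpose(-1, -2)) * scale).softmax(dim=-1) @ v4
ref = ref.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)

assert torch.allclose(out, ref, atol=1e-5)
print("3D (batched-heads) path matches the 4D reference")

Folding the head dimension into the batch dimension keeps every tensor 3D, so a single torch.baddbmm can apply the 1/sqrt(d) scaling while computing the scores, and the num_heads > 1 branching the diff removes is no longer needed.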