cafeai 2022-12-09 12:41:50 +09:00
parent dc2dc363dd
commit f792752abc
1 changed file with 6 additions and 6 deletions


@@ -60,12 +60,6 @@ def modCAForward(self, hidden_states, context=None, mask=None):
key = self.to_k(context)
value = self.to_v(context)
- dim = query.shape[-1]
- query = self.reshape_heads_to_batch_dim(query)
- key = self.reshape_heads_to_batch_dim(key)
- value = self.reshape_heads_to_batch_dim(value)
#added-------------------------
if enable_tome and is_self_attention and hidden_states.shape[1] >= 2048:
with torch.no_grad():
@@ -79,6 +73,12 @@ def modCAForward(self, hidden_states, context=None, mask=None):
value, _ = merge_wavg(merge, value, None)
#------------------------------
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
# attention, what we cannot get enough of
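
Read together, the two hunks only move four existing lines: the ToMe token merge now sees query/key/value while they are still shaped [batch, tokens, dim], and reshape_heads_to_batch_dim runs afterwards. Below is a minimal, self-contained sketch of that ordering, not the repository's code: averaging adjacent token pairs stands in for the elided ToMe merge (which the actual patch builds under torch.no_grad() and applies with merge_wavg), and a local reshape_heads_to_batch_dim mimics the diffusers helper of the same name.

import torch

def qkv_preprocess_sketch(query, key, value, heads, enable_tome=True):
    # query/key/value: [batch, tokens, dim], as they are right after to_q/to_k/to_v.

    # 1) Token merging happens FIRST, while q/k/v still share the token axis
    #    with hidden_states. (Stand-in: average adjacent token pairs; the real
    #    patch computes `merge` under torch.no_grad() and calls merge_wavg.)
    if enable_tome and query.shape[1] >= 2048:
        def merge(x):
            b, t, d = x.shape
            return x.reshape(b, t // 2, 2, d).mean(dim=2)
        query, key, value = merge(query), merge(key), merge(value)

    # 2) Only now are the heads folded into the batch dimension, i.e. the four
    #    lines the commit moved below the ToMe block.
    def reshape_heads_to_batch_dim(x):
        b, t, d = x.shape
        x = x.reshape(b, t, heads, d // heads).permute(0, 2, 1, 3)
        return x.reshape(b * heads, t, d // heads)

    dim = query.shape[-1]
    query = reshape_heads_to_batch_dim(query)
    key = reshape_heads_to_batch_dim(key)
    value = reshape_heads_to_batch_dim(value)
    return query, key, value, dim


# Example: 4096 latent tokens, 8 heads, 320 channels.
q = k = v = torch.randn(2, 4096, 320)
q, k, v, d = qkv_preprocess_sketch(q, k, v, heads=8)
print(q.shape)  # torch.Size([16, 2048, 40]) -- tokens merged first, heads split second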