test
This commit is contained in:
parent
dc2dc363dd
commit
f792752abc
|
@ -60,12 +60,6 @@ def modCAForward(self, hidden_states, context=None, mask=None):
|
|||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
|
||||
dim = query.shape[-1]
|
||||
|
||||
query = self.reshape_heads_to_batch_dim(query)
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
#added-------------------------
|
||||
if enable_tome and is_self_attention and hidden_states.shape[1] >= 2048:
|
||||
with torch.no_grad():
|
||||
|
@ -79,6 +73,12 @@ def modCAForward(self, hidden_states, context=None, mask=None):
|
|||
value, _ = merge_wavg(merge, value, None)
|
||||
#------------------------------
|
||||
|
||||
dim = query.shape[-1]
|
||||
|
||||
query = self.reshape_heads_to_batch_dim(query)
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
|
|
Loading…
Reference in New Issue