This commit is contained in:
Felix Marty 2024-06-13 07:48:18 +00:00
parent 31b8cc4386
commit 64182534b6
1 changed file with 5 additions and 6 deletions

View File

@@ -222,18 +222,17 @@ class FlashGPT2Attention(torch.nn.Module):
max_s,
step
):
query, key, value = self.query_key_value(hidden_states).split(
qkv = self.query_key_value(hidden_states)
if self.layer_idx < 5:
torch.save(qkv, f"qkv_step{step}_layer{self.layer_idx}.pt")
query, key, value = qkv.split(
self.head_size * self.num_heads, dim=1
)
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
if self.layer_idx < 5:
torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor