debug
parent 31b8cc4386
commit 64182534b6
@@ -222,18 +222,17 @@ class FlashGPT2Attention(torch.nn.Module):
         max_s,
         step
     ):
-        query, key, value = self.query_key_value(hidden_states).split(
+        qkv = self.query_key_value(hidden_states)
+        if self.layer_idx < 5:
+            torch.save(qkv, f"qkv_step{step}_layer{self.layer_idx}.pt")
+
+        query, key, value = qkv.split(
             self.head_size * self.num_heads, dim=1
         )
         query = query.view(-1, self.num_heads, self.head_size)
         key = key.view(-1, self.num_heads, self.head_size)
         value = value.view(-1, self.num_heads, self.head_size)
 
-        if self.layer_idx < 5:
-            torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
-            torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
-            torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
-
         reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
 
         # output tensor
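For context, a minimal sketch (not part of this commit) of how the dumped tensors could be inspected afterwards, assuming two runs write their qkv_step{step}_layer{layer}.pt files into separate directories; the directory names and the helper itself are hypothetical:

# Hypothetical helper, not part of this commit: compare the qkv tensors that the
# torch.save call above dumps for the first 5 layers across two separate runs.
import torch

def compare_qkv_dumps(dir_a: str, dir_b: str, step: int, num_layers: int = 5, atol: float = 1e-3):
    for layer in range(num_layers):
        fname = f"qkv_step{step}_layer{layer}.pt"
        a = torch.load(f"{dir_a}/{fname}")
        b = torch.load(f"{dir_b}/{fname}")
        max_diff = (a - b).abs().max().item()
        status = "OK" if torch.allclose(a, b, atol=atol) else "MISMATCH"
        print(f"{fname}: {status} (max abs diff {max_diff:.3e})")

# e.g. compare_qkv_dumps("run_reference", "run_debug", step=0)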