debug
This commit is contained in:
parent
31b8cc4386
commit
64182534b6
|
@ -222,18 +222,17 @@ class FlashGPT2Attention(torch.nn.Module):
|
||||||
max_s,
|
max_s,
|
||||||
step
|
step
|
||||||
):
|
):
|
||||||
query, key, value = self.query_key_value(hidden_states).split(
|
qkv = self.query_key_value(hidden_states)
|
||||||
|
if self.layer_idx < 5:
|
||||||
|
torch.save(qkv, f"qkv_step{step}_layer{self.layer_idx}.pt")
|
||||||
|
|
||||||
|
query, key, value = qkv.split(
|
||||||
self.head_size * self.num_heads, dim=1
|
self.head_size * self.num_heads, dim=1
|
||||||
)
|
)
|
||||||
query = query.view(-1, self.num_heads, self.head_size)
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
key = key.view(-1, self.num_heads, self.head_size)
|
key = key.view(-1, self.num_heads, self.head_size)
|
||||||
value = value.view(-1, self.num_heads, self.head_size)
|
value = value.view(-1, self.num_heads, self.head_size)
|
||||||
|
|
||||||
if self.layer_idx < 5:
|
|
||||||
torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
|
|
||||||
torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
|
|
||||||
torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
|
|
||||||
|
|
||||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
# output tensor
|
# output tensor
|
||||||
|
|
Loading…
Reference in New Issue