debug
This commit is contained in:
parent
8f1de30b0f
commit
31b8cc4386
|
@ -175,6 +175,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
layer_idx
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
|
@ -189,6 +190,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
f"and `num_shards`: {weights.process_group.size()}"
|
||||
)
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
self.query_key_value = load_qkv(
|
||||
config,
|
||||
|
@ -218,6 +220,7 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
step
|
||||
):
|
||||
query, key, value = self.query_key_value(hidden_states).split(
|
||||
self.head_size * self.num_heads, dim=1
|
||||
|
@ -226,6 +229,11 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
key = key.view(-1, self.num_heads, self.head_size)
|
||||
value = value.view(-1, self.num_heads, self.head_size)
|
||||
|
||||
if self.layer_idx < 5:
|
||||
torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
|
||||
torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
|
||||
torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
|
||||
|
||||
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
|
||||
|
||||
# output tensor
|
||||
|
@ -257,6 +265,9 @@ class FlashGPT2Attention(torch.nn.Module):
|
|||
max_s,
|
||||
)
|
||||
|
||||
if self.layer_idx < 5:
|
||||
torch.save(attn_output, f"flash_attn_out_step{step}_layer{self.layer_idx}.pt")
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
|
||||
|
||||
|
@ -301,7 +312,7 @@ class FlashGPT2Layer(nn.Module):
|
|||
def __init__(self, prefix, config, weights, layer_idx):
|
||||
super().__init__()
|
||||
self.self_attn = FlashGPT2Attention(
|
||||
- prefix=f"{prefix}.attn", config=config, weights=weights
|
||||
+ prefix=f"{prefix}.attn", config=config, weights=weights, layer_idx=layer_idx
|
||||
)
|
||||
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
|
@ -342,6 +353,7 @@ class FlashGPT2Layer(nn.Module):
|
|||
slots,
|
||||
input_lengths,
|
||||
max_s,
|
||||
step
|
||||
)
|
||||
|
||||
if self.layer_idx < 5:
|
||||
|
|
Loading…
Reference in New Issue