debug
parent b3e9a13e27
commit 8f1de30b0f
@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
 
 
 class FlashGPT2Layer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_idx):
         super().__init__()
         self.self_attn = FlashGPT2Attention(
             prefix=f"{prefix}.attn", config=config, weights=weights
@@ -313,6 +313,7 @@ class FlashGPT2Layer(nn.Module):
             weights=weights,
             eps=config.layer_norm_epsilon,
         )
+        self.layer_idx = layer_idx
 
     def forward(
         self,
@@ -324,10 +325,14 @@ class FlashGPT2Layer(nn.Module):
         slots,
         input_lengths,
         max_s,
+        step,
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
 
+        if self.layer_idx < 5:
+            torch.save(hidden_states, f"hidden_states_bef_attn_step{step}_layer{self.layer_idx}.pt")
+
         # Self Attention
         attn_output = self.self_attn(
             hidden_states,
@@ -339,6 +344,9 @@ class FlashGPT2Layer(nn.Module):
             max_s,
         )
 
+        if self.layer_idx < 5:
+            torch.save(attn_output, f"attn_output_step{step}_layer{self.layer_idx}.pt")
+
         hidden_states = attn_output + residual
         residual = hidden_states
 
@@ -346,6 +354,9 @@ class FlashGPT2Layer(nn.Module):
 
         mlp_output = self.mlp(hidden_states)
 
+        if self.layer_idx < 5:
+            torch.save(mlp_output, f"mlp_output_step{step}_layer{self.layer_idx}.pt")
+
         return residual + mlp_output, residual
 
 
@@ -364,6 +375,7 @@ class FlashGPT2Model(torch.nn.Module):
                     ),
                     config=config,
                     weights=weights,
+                    layer_idx=layer_id
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -379,6 +391,7 @@ class FlashGPT2Model(torch.nn.Module):
 
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads
+        self.step = 0
 
     def forward(
         self,
@@ -406,7 +419,9 @@ class FlashGPT2Model(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
+                self.step,
             )
+        self.step += 1
 
         hidden_states = self.norm(hidden_states)
 
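Offline, the tensors dumped by this commit can be diffed against a second run (e.g., a non-flash reference implementation) to localize the first layer and operation where activations diverge. A minimal sketch, assuming the .pt files from the two runs were collected into hypothetical ref/ and flash/ directories:

    # compare_dumps.py -- compare the tensors dumped by the debug hooks above.
    # The ref/ and flash/ directory names are assumptions; point them at
    # wherever the .pt files from each run were collected.
    import torch

    NAMES = ["hidden_states_bef_attn", "attn_output", "mlp_output"]

    def compare(step: int, num_layers: int = 5) -> None:
        for layer in range(num_layers):
            for name in NAMES:
                fname = f"{name}_step{step}_layer{layer}.pt"
                ref = torch.load(f"ref/{fname}", map_location="cpu").float()
                test = torch.load(f"flash/{fname}", map_location="cpu").float()
                # Max absolute difference pinpoints where the two runs drift apart.
                diff = (ref - test).abs().max().item()
                print(f"step {step} layer {layer} {name}: max abs diff {diff:.3e}")

    if __name__ == "__main__":
        compare(step=0)

Note that self.step increments once per FlashGPT2Model.forward call, so step 0 is typically the prefill pass and later steps are individual decode steps; since the filenames are fixed, the dumps are overwritten on each fresh run.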