This commit is contained in:
Felix Marty 2024-06-13 07:31:11 +00:00
parent b3e9a13e27
commit 8f1de30b0f
1 changed file with 16 additions and 1 deletion

View File

@@ -298,7 +298,7 @@ class GPT2MLP(nn.Module):
class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_idx):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
@@ -313,6 +313,7 @@ class FlashGPT2Layer(nn.Module):
weights=weights,
eps=config.layer_norm_epsilon,
)
self.layer_idx = layer_idx
def forward(
self,
@@ -324,10 +325,14 @@ class FlashGPT2Layer(nn.Module):
slots,
input_lengths,
max_s,
step,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if self.layer_idx < 5:
torch.save(hidden_states, f"hidden_states_bef_attn_step{step}_layer{self.layer_idx}.pt")
# Self Attention
attn_output = self.self_attn(
hidden_states,
@@ -339,6 +344,9 @@ class FlashGPT2Layer(nn.Module):
max_s,
)
if self.layer_idx < 5:
torch.save(attn_output, f"attn_output_step{step}_layer{self.layer_idx}.pt")
hidden_states = attn_output + residual
residual = hidden_states
@@ -346,6 +354,9 @@ class FlashGPT2Layer(nn.Module):
mlp_output = self.mlp(hidden_states)
if self.layer_idx < 5:
torch.save(mlp_output, f"mlp_output_step{step}_layer{self.layer_idx}.pt")
return residual + mlp_output, residual
@@ -364,6 +375,7 @@ class FlashGPT2Model(torch.nn.Module):
),
config=config,
weights=weights,
layer_idx=layer_id
)
for layer_id in range(config.num_hidden_layers)
]
@@ -379,6 +391,7 @@ class FlashGPT2Model(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads
self.step = 0
def forward(
self,
@@ -406,7 +419,9 @@ class FlashGPT2Model(torch.nn.Module):
slots,
input_lengths,
max_s,
self.step,
)
self.step += 1
hidden_states = self.norm(hidden_states)