This commit is contained in:
Felix Marty 2024-06-13 07:41:46 +00:00
parent 8f1de30b0f
commit 31b8cc4386
1 changed files with 13 additions and 1 deletions

View File

@ -175,6 +175,7 @@ class FlashGPT2Attention(torch.nn.Module):
prefix: str,
config,
weights,
layer_idx
):
super().__init__()
self.num_heads = config.num_attention_heads
@ -189,6 +190,7 @@ class FlashGPT2Attention(torch.nn.Module):
f"and `num_shards`: {weights.process_group.size()}"
)
self.num_heads = self.num_heads // weights.process_group.size()
self.layer_idx = layer_idx
self.query_key_value = load_qkv(
config,
@ -218,6 +220,7 @@ class FlashGPT2Attention(torch.nn.Module):
slots,
input_lengths,
max_s,
step
):
query, key, value = self.query_key_value(hidden_states).split(
self.head_size * self.num_heads, dim=1
@ -226,6 +229,11 @@ class FlashGPT2Attention(torch.nn.Module):
key = key.view(-1, self.num_heads, self.head_size)
value = value.view(-1, self.num_heads, self.head_size)
if self.layer_idx < 5:
torch.save(query, f"query_step{step}_layer{self.layer_idx}.pt")
torch.save(key, f"key_step{step}_layer{self.layer_idx}.pt")
torch.save(value, f"value_step{step}_layer{self.layer_idx}.pt")
reshape_and_cache(key, value, kv_cache[0], kv_cache[1], slots)
# output tensor
@ -257,6 +265,9 @@ class FlashGPT2Attention(torch.nn.Module):
max_s,
)
if self.layer_idx < 5:
torch.save(attn_output, f"flash_attn_out_step{step}_layer{self.layer_idx}.pt")
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
@ -301,7 +312,7 @@ class FlashGPT2Layer(nn.Module):
def __init__(self, prefix, config, weights, layer_idx):
super().__init__()
self.self_attn = FlashGPT2Attention(
prefix=f"{prefix}.attn", config=config, weights=weights
prefix=f"{prefix}.attn", config=config, weights=weights, layer_idx=layer_idx
)
self.mlp = GPT2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@ -342,6 +353,7 @@ class FlashGPT2Layer(nn.Module):
slots,
input_lengths,
max_s,
step
)
if self.layer_idx < 5: