some cleaning

This commit is contained in:
Felix Marty 2024-06-27 13:33:35 +00:00
parent 3760102077
commit 02ac45131f
2 changed files with 4 additions and 44 deletions

View File

@ -111,7 +111,6 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str, prefix: str,
config, config,
weights, weights,
layer_idx,
): ):
super().__init__() super().__init__()
self.num_heads = config.num_attention_heads self.num_heads = config.num_attention_heads
@ -144,7 +143,6 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index) self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index self.index = index
self.layer_idx = layer_idx
o_proj = TensorParallelRowLinear.load( o_proj = TensorParallelRowLinear.load(
config, config,
@ -165,8 +163,6 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups) ).repeat_interleave(self.num_groups)
self.step = 0
def forward( def forward(
self, self,
hidden_states, hidden_states,
@ -198,18 +194,6 @@ class FlashLlamaAttention(torch.nn.Module):
# output tensor # output tensor
attn_output = torch.empty_like(query) attn_output = torch.empty_like(query)
if self.layer_idx < 4:
torch.save(query, f"query_states_step{self.step}_layer{self.layer_idx}.pt")
if cu_seqlen_prefill is not None:
torch.save(
torch.select(kv, dim=1, index=0),
f"key_states_step{self.step}_layer{self.layer_idx}.pt",
)
torch.save(
torch.select(kv, dim=1, index=1),
f"value_states_step{self.step}_layer{self.layer_idx}.pt",
)
# Prefill # Prefill
if cu_seqlen_prefill is not None: if cu_seqlen_prefill is not None:
# flash attention # flash attention
@ -236,15 +220,10 @@ class FlashLlamaAttention(torch.nn.Module):
max_s, max_s,
) )
attn_output = attn_output.view(-1, self.num_heads * self.head_size) return self.o_proj(
if self.layer_idx < 4: attn_output.view(-1, self.num_heads * self.head_size), adapter_data
torch.save(
attn_output, f"attn_output_step{self.step}_layer{self.layer_idx}.pt"
) )
self.step += 1
return self.o_proj(attn_output, adapter_data)
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
def __init__(self, prefix, config, weights, index): def __init__(self, prefix, config, weights, index):
@ -342,14 +321,13 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module): class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights, layer_idx): def __init__(self, index, prefix, config, weights):
super().__init__() super().__init__()
self.self_attn = FlashLlamaAttention( self.self_attn = FlashLlamaAttention(
index=index, index=index,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
config=config, config=config,
weights=weights, weights=weights,
layer_idx=layer_idx,
) )
self.mlp = LlamaMLP( self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@ -422,7 +400,6 @@ class FlashLlamaModel(torch.nn.Module):
), ),
config=config, config=config,
weights=weights, weights=weights,
layer_idx=layer_id,
) )
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]

View File

@ -1149,23 +1149,6 @@ class FlashCausalLM(Model):
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
logger.info(f"input_ids {input_ids} {input_ids.shape}")
logger.info(f"position_ids {position_ids} {position_ids.shape}")
logger.info(
f"cu_seqlen_prefill {cu_seqlen_prefill} {cu_seqlen_prefill.shape if cu_seqlen_prefill is not None else 'NONE'}"
)
logger.info(
f"kv_cache {type(kv_cache)}, len={len(kv_cache)}, {len(kv_cache[0])}, shape={kv_cache[0][0].shape}"
)
logger.info(
f"block_tables {type(block_tables)} {block_tables.shape} {block_tables}"
)
logger.info(f"slots {type(slots)} {slots.shape} {slots}")
logger.info(f"input_lengths {input_lengths}")
logger.info(f"max_s {max_s}")
logger.info(f"prefill_cache_indices {batch.prefill_cache_indices}")
logger.info(f"lm_head_indices {lm_head_indices}")
logger.info(f"adapter_data {adapter_data}")
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,