some cleaning
commit 02ac45131f
parent 3760102077
@@ -111,7 +111,6 @@ class FlashLlamaAttention(torch.nn.Module):
         prefix: str,
         config,
         weights,
-        layer_idx,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -144,7 +143,6 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
-        self.layer_idx = layer_idx
 
         o_proj = TensorParallelRowLinear.load(
             config,
@@ -165,8 +163,6 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
 
-        self.step = 0
-
     def forward(
         self,
         hidden_states,
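Note: the `repeat_interleave` context line above builds the grouped-query attention head mapping, pointing each query head at the KV head it shares. A minimal standalone sketch, with head counts assumed purely for illustration:

```python
import torch

# Assumed values for illustration: 8 query heads sharing 2 KV heads.
num_heads = 8
num_key_value_heads = 2
num_groups = num_heads // num_key_value_heads  # 4 query heads per KV head

kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)
print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```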
@@ -198,18 +194,6 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)
 
-        if self.layer_idx < 4:
-            torch.save(query, f"query_states_step{self.step}_layer{self.layer_idx}.pt")
-            if cu_seqlen_prefill is not None:
-                torch.save(
-                    torch.select(kv, dim=1, index=0),
-                    f"key_states_step{self.step}_layer{self.layer_idx}.pt",
-                )
-                torch.save(
-                    torch.select(kv, dim=1, index=1),
-                    f"value_states_step{self.step}_layer{self.layer_idx}.pt",
-                )
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -236,14 +220,9 @@ class FlashLlamaAttention(torch.nn.Module):
             max_s,
         )
 
-        attn_output = attn_output.view(-1, self.num_heads * self.head_size)
-        if self.layer_idx < 4:
-            torch.save(
-                attn_output, f"attn_output_step{self.step}_layer{self.layer_idx}.pt"
-            )
-
-        self.step += 1
-        return self.o_proj(attn_output, adapter_data)
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class LlamaMLP(nn.Module):
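Note: the two hunks above remove the ad-hoc debugging dumps that saved the query/key/value tensors and the attention output for the first four layers at every step. If that instrumentation is ever wanted again, gating it behind an environment variable keeps the hot path clean by default; a minimal sketch, where `DEBUG_DUMP_DIR` and `maybe_dump` are hypothetical names, not part of this codebase:

```python
import os
import torch

# Hypothetical: dumps are written only when DEBUG_DUMP_DIR is set.
_DUMP_DIR = os.environ.get("DEBUG_DUMP_DIR")


def maybe_dump(tensor, name, step, layer_idx, max_layers=4):
    """Save a tensor snapshot for the first few layers, if dumping is enabled."""
    if _DUMP_DIR is None or layer_idx >= max_layers:
        return
    path = os.path.join(_DUMP_DIR, f"{name}_step{step}_layer{layer_idx}.pt")
    torch.save(tensor.detach().cpu(), path)
```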
@@ -342,14 +321,13 @@ class LlamaMLP(nn.Module):
 
 
 class FlashLlamaLayer(nn.Module):
-    def __init__(self, index, prefix, config, weights, layer_idx):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
             index=index,
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
-            layer_idx=layer_idx,
         )
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -422,7 +400,6 @@ class FlashLlamaModel(torch.nn.Module):
                 ),
                 config=config,
                 weights=weights,
-                layer_idx=layer_id,
             )
             for layer_id in range(config.num_hidden_layers)
         ]
@@ -1149,23 +1149,6 @@ class FlashCausalLM(Model):
             cuda_graph = None
 
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logger.info(f"input_ids {input_ids} {input_ids.shape}")
-            logger.info(f"position_ids {position_ids} {position_ids.shape}")
-            logger.info(
-                f"cu_seqlen_prefill {cu_seqlen_prefill} {cu_seqlen_prefill.shape if cu_seqlen_prefill is not None else 'NONE'}"
-            )
-            logger.info(
-                f"kv_cache {type(kv_cache)}, len={len(kv_cache)}, {len(kv_cache[0])}, shape={kv_cache[0][0].shape}"
-            )
-            logger.info(
-                f"block_tables {type(block_tables)} {block_tables.shape} {block_tables}"
-            )
-            logger.info(f"slots {type(slots)} {slots.shape} {slots}")
-            logger.info(f"input_lengths {input_lengths}")
-            logger.info(f"max_s {max_s}")
-            logger.info(f"prefill_cache_indices {batch.prefill_cache_indices}")
-            logger.info(f"lm_head_indices {lm_head_indices}")
-            logger.info(f"adapter_data {adapter_data}")
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
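Note: the hunk above drops the per-forward `logger.info` dump of every batch tensor. If that visibility is still wanted occasionally, demoting the calls to `logger.debug` keeps them behind the log level instead of deleting them outright; a sketch of the idea, assuming a loguru-style logger and logging only shapes to keep the output bounded:

```python
from loguru import logger


# Sketch: shapes only, at DEBUG level, so production INFO logs stay quiet.
def log_forward_inputs(input_ids, position_ids, slots, max_s):
    logger.debug(f"input_ids {input_ids.shape}")
    logger.debug(f"position_ids {position_ids.shape}")
    logger.debug(f"slots {slots.shape}")
    logger.debug(f"max_s {max_s}")
```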