some cleaning

commit 02ac45131f
parent 3760102077
@@ -111,7 +111,6 @@ class FlashLlamaAttention(torch.nn.Module):
         prefix: str,
         config,
         weights,
-        layer_idx,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -144,7 +143,6 @@ class FlashLlamaAttention(torch.nn.Module):

         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
-        self.layer_idx = layer_idx

         o_proj = TensorParallelRowLinear.load(
             config,
@@ -165,8 +163,6 @@ class FlashLlamaAttention(torch.nn.Module):
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

-        self.step = 0
-
     def forward(
         self,
         hidden_states,
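For context on the unchanged kv_head_mapping lines above: torch.arange(...).repeat_interleave(self.num_groups) builds the table that tells paged attention which KV head each query head reads from under grouped-query attention. A standalone sketch with hypothetical head counts (not taken from any config in this diff):

import torch

num_heads = 8              # hypothetical number of query heads
num_key_value_heads = 2    # hypothetical number of KV heads
num_groups = num_heads // num_key_value_heads

# Each KV-head index is repeated num_groups times, so query heads 0-3 share
# KV head 0 and query heads 4-7 share KV head 1.
kv_head_mapping = torch.arange(
    0, num_key_value_heads, dtype=torch.int32
).repeat_interleave(num_groups)

print(kv_head_mapping)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)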
@@ -198,18 +194,6 @@ class FlashLlamaAttention(torch.nn.Module):
         # output tensor
         attn_output = torch.empty_like(query)

-        if self.layer_idx < 4:
-            torch.save(query, f"query_states_step{self.step}_layer{self.layer_idx}.pt")
-            if cu_seqlen_prefill is not None:
-                torch.save(
-                    torch.select(kv, dim=1, index=0),
-                    f"key_states_step{self.step}_layer{self.layer_idx}.pt",
-                )
-                torch.save(
-                    torch.select(kv, dim=1, index=1),
-                    f"value_states_step{self.step}_layer{self.layer_idx}.pt",
-                )
-
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
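The block removed above dumped the query/key/value tensors of the first four layers to .pt files on every call, which is why it had to go once debugging was done. If such dumps are ever needed again, a minimal offline check might look like the sketch below; the file names reuse the pattern from the removed code, while the reference_*.pt names are assumptions for illustration only.

import torch

step, layer_idx = 0, 0  # hypothetical step/layer to inspect

# File names follow the f-string pattern from the removed debug code.
query = torch.load(f"query_states_step{step}_layer{layer_idx}.pt")
key = torch.load(f"key_states_step{step}_layer{layer_idx}.pt")
value = torch.load(f"value_states_step{step}_layer{layer_idx}.pt")

# Hypothetical comparison against tensors saved from a reference run.
ref_query = torch.load(f"reference_query_step{step}_layer{layer_idx}.pt")
print(query.shape, key.shape, value.shape)
print(torch.allclose(query.float(), ref_query.float(), atol=1e-3))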
@@ -236,15 +220,10 @@ class FlashLlamaAttention(torch.nn.Module):
                 max_s,
             )

-        attn_output = attn_output.view(-1, self.num_heads * self.head_size)
-        if self.layer_idx < 4:
-            torch.save(
-                attn_output, f"attn_output_step{self.step}_layer{self.layer_idx}.pt"
-            )
-
-        self.step += 1
-        return self.o_proj(attn_output, adapter_data)
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )


 class LlamaMLP(nn.Module):
     def __init__(self, prefix, config, weights, index):
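The replacement above folds the reshape into the return statement; behaviour is unchanged, since .view(-1, num_heads * head_size) only merges the per-head axes into one feature dimension before the output projection. A tiny standalone illustration with hypothetical sizes:

import torch

num_tokens, num_heads, head_size = 5, 8, 64   # hypothetical sizes
attn_output = torch.randn(num_tokens, num_heads, head_size)

# Merge (num_heads, head_size) into a single hidden dimension per token,
# which is the shape the output projection expects.
flat = attn_output.view(-1, num_heads * head_size)
print(flat.shape)  # torch.Size([5, 512])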
@@ -342,14 +321,13 @@ class LlamaMLP(nn.Module):


 class FlashLlamaLayer(nn.Module):
-    def __init__(self, index, prefix, config, weights, layer_idx):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
             index=index,
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
-            layer_idx=layer_idx,
         )
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -422,7 +400,6 @@ class FlashLlamaModel(torch.nn.Module):
                 ),
                 config=config,
                 weights=weights,
-                layer_idx=layer_id,
             )
             for layer_id in range(config.num_hidden_layers)
         ]
@@ -1149,23 +1149,6 @@ class FlashCausalLM(Model):
             cuda_graph = None

         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logger.info(f"input_ids {input_ids} {input_ids.shape}")
-            logger.info(f"position_ids {position_ids} {position_ids.shape}")
-            logger.info(
-                f"cu_seqlen_prefill {cu_seqlen_prefill} {cu_seqlen_prefill.shape if cu_seqlen_prefill is not None else 'NONE'}"
-            )
-            logger.info(
-                f"kv_cache {type(kv_cache)}, len={len(kv_cache)}, {len(kv_cache[0])}, shape={kv_cache[0][0].shape}"
-            )
-            logger.info(
-                f"block_tables {type(block_tables)} {block_tables.shape} {block_tables}"
-            )
-            logger.info(f"slots {type(slots)} {slots.shape} {slots}")
-            logger.info(f"input_lengths {input_lengths}")
-            logger.info(f"max_s {max_s}")
-            logger.info(f"prefill_cache_indices {batch.prefill_cache_indices}")
-            logger.info(f"lm_head_indices {lm_head_indices}")
-            logger.info(f"adapter_data {adapter_data}")
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
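The lines removed above logged full input tensors on every forward pass, which is far too noisy for normal serving. If similar tracing is wanted later, one option is to gate it behind an environment variable instead of deleting it; this is a hypothetical sketch using the standard logging module, not part of the commit:

import logging
import os

logger = logging.getLogger(__name__)

# Hypothetical opt-in switch; tracing stays off unless TRACE_FORWARD=1 is set.
TRACE_FORWARD = os.environ.get("TRACE_FORWARD", "0") == "1"

def trace_forward_inputs(input_ids, position_ids, slots, max_s):
    # Only format and emit the log lines when tracing is explicitly enabled.
    if not TRACE_FORWARD:
        return
    logger.info(f"input_ids {input_ids.shape}")
    logger.info(f"position_ids {position_ids.shape}")
    logger.info(f"slots {slots.shape}")
    logger.info(f"max_s {max_s}")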