Hotfixing after refactor.
parent fb2f74e2b9
commit 853d4eb9cf
@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
             reduce=False,
         )

-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )

-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads

     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)

         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
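Taken together, the hunks above chase the refactor's renames through the whole file: the attribute attn becomes self_attn, h becomes layers, and transformer becomes model, and every internal reference (the forward loop, the head_size / num_heads accessors) is updated to match. Note that the weight-loading prefixes such as f"{prefix}.attn", "transformer.ln_f", and "transformer.wte" are string keys into the checkpoint and are left untouched; only the Python attribute names move. A minimal sketch of the same pattern, using hypothetical TinyBlock / TinyModel classes that are not from the repository:

import torch
from torch import nn


class TinyBlock(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        # Attribute renamed attn -> self_attn; a checkpoint prefix such as
        # f"{prefix}.attn" would stay the same string after the rename.
        self.self_attn = nn.MultiheadAttention(hidden, num_heads=2, batch_first=True)
        self.ln = nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y, _ = self.self_attn(x, x, x)
        return self.ln(x + y)


class TinyModel(nn.Module):
    def __init__(self, hidden: int = 8, n_layers: int = 2):
        super().__init__()
        # Attribute renamed h -> layers; every later use must follow suit.
        self.layers = nn.ModuleList(TinyBlock(hidden) for _ in range(n_layers))
        # Accessor updated together with the rename, as in the hunk above.
        self.num_heads = self.layers[0].self_attn.num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:  # was: for layer in self.h
            x = layer(x)
        return x


x = torch.randn(1, 4, 8)
print(TinyModel()(x).shape)  # torch.Size([1, 4, 8])

Missing any one of these usage updates would raise an AttributeError at runtime, which is exactly the kind of breakage this hotfix repairs.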
@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id
@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
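The two Model hunks switch target_to_layer to lazy initialization: instead of building the mapping in the constructor, it is built on first use in the adapter-loading path, so the call runs only once and only when an adapter is actually loaded. A minimal sketch of that pattern, with illustrative names (AdapterHost, the stand-in mapping) that are not from the repository:

from typing import Dict, Optional, Set


class AdapterHost:
    """Sketch of the lazy-initialization pattern applied to target_to_layer."""

    def __init__(self) -> None:
        # Was: self.target_to_layer = self.adapter_target_to_layer()
        # Now: defer the build until the mapping is first needed.
        self.target_to_layer: Optional[Dict[str, int]] = None
        self.loaded_adapters: Set[int] = set()

    def adapter_target_to_layer(self) -> Dict[str, int]:
        # Stand-in for the real mapping construction.
        return {"q_proj": 0, "v_proj": 0}

    def load_adapter(self, adapter_index: int) -> None:
        # Build the mapping lazily, exactly once.
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            return  # adapter already loaded
        self.loaded_adapters.add(adapter_index)


host = AdapterHost()
host.load_adapter(0)
print(host.target_to_layer)  # {'q_proj': 0, 'v_proj': 0}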