Hotfixing after refactor.
parent fb2f74e2b9
commit 853d4eb9cf
@@ -355,7 +355,7 @@ class Block(nn.Module):
         self.ln_2 = FastLayerNorm.load(
             prefix=f"{prefix}.ln_2", weights=weights, eps=config.layer_norm_epsilon
         )
-        self.attn = FlashMQAttention(
+        self.self_attn = FlashMQAttention(
             prefix=f"{prefix}.attn",
             config=config,
             weights=weights,
@@ -378,7 +378,7 @@ class Block(nn.Module):
         max_s,
     ):
         hidden_states, residual = self.ln_1(hidden_states, residual)
-        hidden_states = self.attn(
+        hidden_states = self.self_attn(
             hidden_states,
             cu_seqlen_prefill,
             kv_cache,
@@ -412,7 +412,7 @@ class FlashSantacoderModel(nn.Module):
             reduce=False,
         )

-        self.h = nn.ModuleList(
+        self.layers = nn.ModuleList(
             [
                 Block(
                     layer_id,
@@ -426,8 +426,8 @@ class FlashSantacoderModel(nn.Module):
             prefix="transformer.ln_f", weights=weights, eps=config.layer_norm_epsilon
         )

-        self.head_size = self.h[0].attn.head_size
-        self.num_heads = self.h[0].attn.num_heads
+        self.head_size = self.layers[0].self_attn.head_size
+        self.num_heads = self.layers[0].self_attn.num_heads

     def forward(
         self,
@@ -446,7 +446,7 @@ class FlashSantacoderModel(nn.Module):
             torch.distributed.all_reduce(hidden_states, group=self.process_group)

         residual = None
-        for i, layer in enumerate(self.h):
+        for i, layer in enumerate(self.layers):
             hidden_states, residual = layer(
                 hidden_states,
                 residual,
@@ -467,7 +467,7 @@ class FlashSantacoderForCausalLM(nn.Module):
     def __init__(self, prefix, config, weights):
         super().__init__()
         config.transpose = config.architectures[0].startswith("GPT2")
-        self.transformer = FlashSantacoderModel(config, weights)
+        self.model = FlashSantacoderModel(config, weights)
         self.lm_head = SpeculativeHead.load(
             config, prefix="transformer.wte", weights=weights
         )
@@ -486,7 +486,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         lm_head_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(
+        hidden_states = self.model(
             input_ids,
             position_ids,
             cu_seqlen_prefill,
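The hunks above rename the Santacoder attributes without touching the checkpoint prefixes: self.attn becomes self.self_attn, self.h becomes self.layers, and self.transformer becomes self.model, while weights are still loaded from f"{prefix}.attn", "transformer.ln_f" and "transformer.wte". A minimal sketch of how callers reach the attention module after the rename; the `model` argument is assumed to be an already constructed FlashSantacoderForCausalLM and is not part of the commit:

# Sketch only, not part of the commit: attribute path before and after the rename.
def head_size_after_rename(model):
    # previously: model.transformer.h[0].attn.head_size
    return model.model.layers[0].self_attn.head_size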
@@ -60,7 +60,7 @@ class Model(ABC):
         self.layer_to_adapter_weights: Dict[str, LayerAdapterWeights] = defaultdict(
             LayerAdapterWeights
         )
-        self.target_to_layer = self.adapter_target_to_layer()
+        self.target_to_layer = None
         self.loaded_adapters = set()
         self.static_adapter_id = adapter_id

@@ -187,6 +187,8 @@ class Model(ABC):
         into model. Otherwise, the adapter weights are applied during the forward
         pass and stored separately from the base model parameters.
         """
+        if self.target_to_layer is None:
+            self.target_to_layer = self.adapter_target_to_layer()
         if adapter_index in self.loaded_adapters:
             # Adapter already loaded
             return
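Taken together, the two Model(ABC) hunks move construction of target_to_layer out of __init__ (it is now initialised to None) and into the first load_adapter call. A minimal standalone sketch of that lazy-initialisation pattern, with everything the diff does not show stubbed out (the stubs are assumptions, not the real TGI code):

# Sketch only: lazy construction of the adapter target-to-layer mapping.
class Model:
    def __init__(self):
        self.target_to_layer = None  # built on first use, not at construction time
        self.loaded_adapters = set()

    def adapter_target_to_layer(self):
        # stand-in for the real mapping from adapter target names to model layers
        return {}

    def load_adapter(self, adapter_index):
        if self.target_to_layer is None:
            self.target_to_layer = self.adapter_target_to_layer()
        if adapter_index in self.loaded_adapters:
            # adapter already loaded
            return
        self.loaded_adapters.add(adapter_index)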