fix: add adapter_data param and avoid missing layers

drbh 2024-06-07 03:03:15 +00:00
parent 91f407226d
commit b1169273fd
5 changed files with 19 additions and 12 deletions
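The signature changes below all follow one pattern: every flash model's forward() gains the same optional adapter_data keyword, so the shared serving code can pass it to any model unconditionally, including models that simply ignore it. A minimal, self-contained sketch of that calling convention, using hypothetical names (TinyCausalLM is illustrative, not one of the models touched in this commit):

from typing import Optional

import torch
from torch import nn


class TinyCausalLM(nn.Module):
    """Hypothetical stand-in for the flash models; not TGI's actual API."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        lm_head_indices: Optional[torch.Tensor] = None,
        adapter_data: Optional[torch.Tensor] = None,  # accepted even if unused
    ) -> torch.Tensor:
        hidden_states = self.embed(input_ids)
        if lm_head_indices is not None:
            hidden_states = hidden_states[lm_head_indices]
        return self.lm_head(hidden_states)


model = TinyCausalLM()
# The shared caller can now pass adapter_data to every model unconditionally.
logits = model(torch.tensor([1, 2, 3]), adapter_data=None)
print(logits.shape)  # torch.Size([3, 16])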


@@ -672,6 +672,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,


@@ -488,6 +488,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,


@@ -525,6 +525,7 @@ class FlashStarcoder2ForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         true_max_s = max_s
         if prefill_cache_indices is not None:


@@ -147,18 +147,21 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.self_attn.o_proj,
             )
-            layer_weights[(i, "gate_proj")] = (
-                f"{prefix}.{i}.mlp.gate_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "up_proj")] = (
-                f"{prefix}.{i}.mlp.up_proj",
-                layer.mlp.gate_up_proj,
-            )
-            layer_weights[(i, "down_proj")] = (
-                f"{prefix}.{i}.mlp.down_proj",
-                layer.mlp.down_proj,
-            )
+            # TODO: this is a hack to avoid the gate_proj for
+            # FlashStarcoder2 that doesnt have these layers
+            if hasattr(layer.mlp, "gate_up_proj"):
+                layer_weights[(i, "gate_proj")] = (
+                    f"{prefix}.{i}.mlp.gate_proj",
+                    layer.mlp.gate_up_proj,
+                )
+                layer_weights[(i, "up_proj")] = (
+                    f"{prefix}.{i}.mlp.up_proj",
+                    layer.mlp.gate_up_proj,
+                )
+                layer_weights[(i, "down_proj")] = (
+                    f"{prefix}.{i}.mlp.down_proj",
+                    layer.mlp.down_proj,
+                )
         layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
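The guard above registers the MLP adapter targets only when the module actually exposes gate_up_proj, so architectures without those layers (per the TODO, FlashStarcoder2) are skipped instead of raising AttributeError. A minimal sketch of the same pattern with assumed names (collect_mlp_targets, FusedMLP, and PlainMLP are illustrative, not TGI code):

from torch import nn


def collect_mlp_targets(mlp: nn.Module, prefix: str) -> dict:
    targets = {}
    # Only register targets that this architecture actually defines.
    if hasattr(mlp, "gate_up_proj"):
        targets["gate_proj"] = (f"{prefix}.gate_proj", mlp.gate_up_proj)
        targets["up_proj"] = (f"{prefix}.up_proj", mlp.gate_up_proj)
        targets["down_proj"] = (f"{prefix}.down_proj", mlp.down_proj)
    return targets


class FusedMLP(nn.Module):
    """MLP with a fused gate/up projection, as the mapping above expects."""

    def __init__(self):
        super().__init__()
        self.gate_up_proj = nn.Linear(8, 32)
        self.down_proj = nn.Linear(16, 8)


class PlainMLP(nn.Module):
    """MLP without gate_up_proj; nothing gets registered for it."""

    def __init__(self):
        super().__init__()
        self.c_fc = nn.Linear(8, 32)
        self.c_proj = nn.Linear(32, 8)


print(collect_mlp_targets(FusedMLP(), "model.layers.0.mlp").keys())  # gate/up/down
print(collect_mlp_targets(PlainMLP(), "model.layers.0.mlp"))         # {}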


@@ -634,6 +634,7 @@ class IdeficsCausalLM(Model):
         tokenizer.add_special_tokens({"pad_token": "<unk>"})
         super(IdeficsCausalLM, self).__init__(
             model_id=model_id,
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,