feat: support for vlm models
parent a563a93113
commit 91f407226d
@@ -41,6 +41,8 @@ class LoraLinear(nn.Module):
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
+        if adapter_data is None:
+            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
             data.get(LORA) if data is not None else None
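Note: the two added lines give the LoRA layer an early exit when a batch carries no adapter data. Below is a minimal sketch of that fast path; lora_forward, lora_a, and lora_b are illustrative names, not the repository's API.

from typing import Optional

import torch


def lora_forward(
    x: torch.Tensor,
    base_weight: torch.Tensor,
    adapter_data: Optional[dict],
    lora_a: Optional[torch.Tensor] = None,
    lora_b: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # The base projection always runs.
    result = x @ base_weight
    if adapter_data is None:
        # Fast path, as in the hunk above: no adapters in this batch,
        # so the base output is returned unchanged.
        return result
    # Otherwise apply the low-rank LoRA update: x @ A @ B.
    return result + (x @ lora_a) @ lora_b


# A text-only request with no adapters skips the LoRA math entirely.
x = torch.randn(2, 16)
w = torch.randn(16, 16)
assert torch.equal(lora_forward(x, w, adapter_data=None), x @ w)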
@@ -108,7 +108,17 @@ class FlashLlama(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
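Note: the hasattr chain is what lets the same layer-weight collection serve both plain language models and VLM wrappers such as LlavaNext or Idefics2, which nest the decoder under language_model (or text_model). A small sketch of the resolution order with stand-in classes; all names below are illustrative.

class _Decoder:
    """Stand-in for the inner module that owns .model.layers."""


class _PlainLM:
    """Stand-in for a text-only model: the decoder hangs directly off it."""

    def __init__(self):
        self.model = _Decoder()


class _LlavaLike:
    """Stand-in for a VLM wrapper: the decoder lives under .language_model."""

    def __init__(self):
        self.language_model = _PlainLM()


def resolve_text_model(model):
    # Same fallback order as the diff above: language_model, then text_model,
    # then the model itself.
    if hasattr(model, "language_model"):
        return model.language_model
    if hasattr(model, "text_model"):
        return model.text_model
    return model


assert isinstance(resolve_text_model(_PlainLM()), _PlainLM)
assert isinstance(resolve_text_model(_LlavaLike()), _PlainLM)  # unwrapped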
@@ -139,7 +149,7 @@ class FlashLlama(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
@@ -151,7 +161,7 @@ class FlashLlama(FlashCausalLM):
         return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
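Note: whichever spelling is used for the head key, the rule is that the LM head counts as a single layer while per-decoder-layer projections count len(self.model.model.layers) times. A free-standing sketch of that rule; the signature here is illustrative, not the class method.

def get_num_layers_for_type(num_decoder_layers: int, layer_type: str) -> int:
    # The LM head is a single tensor, while projections like q_proj/v_proj
    # exist once per decoder layer.
    return 1 if layer_type == "lm_head" else num_decoder_layers


assert get_num_layers_for_type(32, "lm_head") == 1
assert get_num_layers_for_type(32, "q_proj") == 32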
@@ -119,7 +119,17 @@ class BaseFlashMistral(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -150,7 +160,7 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
@@ -453,6 +453,7 @@ class Mamba(Model):
         model = MambaModel(config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(Mamba, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
@@ -219,7 +219,9 @@ class VlmCausalLM(BaseFlashMistral):
         return VlmCausalLMBatch
 
     def forward(
-        self, batch: VlmCausalLMBatch
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
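Note: making adapter_data an optional keyword with a None default keeps existing callers of VlmCausalLM.forward working while letting the LoRA path thread per-request adapter tensors through. A hedged sketch of the calling pattern; the class below is a stand-in, not the real model.

from typing import Dict, Optional, Tuple

import torch


class _SketchVlmLM:
    """Illustrative stand-in, not the actual VlmCausalLM class."""

    def forward(
        self,
        batch: object,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # adapter_data defaults to None, so call sites that predate LoRA
        # support keep working; LoRA-aware callers pass the extra dict.
        logits = torch.zeros(1, 4)
        return logits, None


model = _SketchVlmLM()
model.forward(batch=object())                                            # old call site
model.forward(batch=object(), adapter_data={"q_proj": torch.zeros(8)})   # new call site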