feat: support if vlm models
This commit is contained in:
parent a563a93113
commit 91f407226d
@@ -41,6 +41,8 @@ class LoraLinear(nn.Module):
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
+        if adapter_data is None:
+            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
             data.get(LORA) if data is not None else None
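A minimal sketch of the early-exit pattern this hunk adds, using hypothetical stand-in names rather than the real LoraLinear/AdapterBatchData types: when no adapter data is attached, the LoRA path is skipped and the base projection result is returned unchanged, which is what the VLM code paths rely on when they call in without adapters.

from typing import Optional

import torch


def apply_lora(
    result: torch.Tensor,
    adapter_data: Optional[dict],  # hypothetical stand-in for the real adapter batch data
) -> torch.Tensor:
    # Same early-exit pattern as the hunk above: without adapter data the
    # base layer output passes through untouched.
    if adapter_data is None:
        return result
    # ... the LoRA delta would be added to `result` here ...
    return result


base_out = torch.zeros(2, 4)
assert torch.equal(apply_lora(base_out, None), base_out)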
@@ -108,7 +108,17 @@ class FlashLlama(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -139,7 +149,7 @@ class FlashLlama(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
@@ -151,7 +161,7 @@ class FlashLlama(FlashCausalLM):
         return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
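A self-contained sketch of the hasattr-based resolution the FlashLlama hunks introduce, with hypothetical stand-in classes (not the actual transformers models): a multimodal wrapper exposes its decoder under language_model (or text_model), while a plain causal LM is used directly.

class _Decoder:
    """Stand-in for the inner text decoder that owns model.layers."""


class PlainLM:
    pass


class LlavaStyleVLM:
    def __init__(self):
        self.language_model = _Decoder()


def resolve_text_model(model):
    # Mirrors the hunk: VLMs (e.g. LlavaNext, Idefics2) nest the language
    # model inside the larger multimodal model.
    if hasattr(model, "language_model"):
        return model.language_model
    elif hasattr(model, "text_model"):
        return model.text_model
    return model


assert isinstance(resolve_text_model(LlavaStyleVLM()), _Decoder)
assert isinstance(resolve_text_model(PlainLM()), PlainLM)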
@@ -119,7 +119,17 @@ class BaseFlashMistral(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -150,7 +160,7 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.mlp.down_proj,
            )
 
-        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
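For context, a toy illustration (hypothetical names, not the real API) of the (layer_index, layer_type) -> (weight_name, module) mapping that both get_layer_weights implementations build; the head is a single entry keyed at index 0, which is why get_num_layers_for_type special-cases "lm_head".

from typing import Dict, Tuple

# Toy layer count; in the real models this comes from len(_model.model.layers).
NUM_LAYERS = 2

layer_weights: Dict[Tuple[int, str], Tuple[str, object]] = {}
for i in range(NUM_LAYERS):
    # Placeholder objects stand in for modules such as layer.self_attn.query_key_value.
    layer_weights[(i, "q_proj")] = (f"model.layers.{i}.self_attn.q_proj", object())
    layer_weights[(i, "v_proj")] = (f"model.layers.{i}.self_attn.v_proj", object())

# The head is a single entry keyed at layer index 0.
layer_weights[(0, "lm_head")] = ("lm_head", object())


def get_num_layers_for_type(layer_type: str) -> int:
    # Same shape as the method in the hunks above.
    return 1 if layer_type == "lm_head" else NUM_LAYERS


assert get_num_layers_for_type("q_proj") == NUM_LAYERS
assert get_num_layers_for_type("lm_head") == 1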
@@ -453,6 +453,7 @@ class Mamba(Model):
         model = MambaModel(config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(Mamba, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
@@ -219,7 +219,9 @@ class VlmCausalLM(BaseFlashMistral):
         return VlmCausalLMBatch
 
     def forward(
-        self, batch: VlmCausalLMBatch
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
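A hedged sketch (toy signature, not the real VlmCausalLM) of why adapter_data is added with a None default: existing callers that pass only the batch keep working, while adapter-aware callers can thread per-layer LoRA tensors through the forward pass.

from typing import Dict, Optional, Tuple

import torch


def forward(
    batch: object,  # stand-in for VlmCausalLMBatch
    adapter_data: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # With a default of None, pre-existing call sites such as forward(batch)
    # keep working; only adapter-aware callers pass the extra dict.
    logits = torch.zeros(1, 4)
    speculative_logits = None
    return logits, speculative_logits


forward(object())                                 # old-style call still works
forward(object(), {"q_proj": torch.zeros(1, 1)})  # adapter-aware call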