feat: support for vlm models

drbh 2024-06-07 02:21:06 +00:00
parent a563a93113
commit 91f407226d
5 changed files with 31 additions and 6 deletions

@@ -41,6 +41,8 @@ class LoraLinear(nn.Module):
         start_idx: int,
         end_idx: int,
     ) -> torch.Tensor:
+        if adapter_data is None:
+            return result
         data = adapter_data.data.get(layer_type)
         data: Optional["BatchLoraWeights"] = (
             data.get(LORA) if data is not None else None
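In isolation, the guard added above turns the LoRA path into a no-op whenever a batch carries no adapter data, which lets models without loaded adapters (including the VLM paths touched below) share the same forward code. A minimal standalone sketch of the pattern, using an illustrative layer class rather than the repository's LoraLinear:

from typing import Optional

import torch


class GuardedLoraLayer(torch.nn.Module):
    """Illustrative only: applies a LoRA delta when adapter data is present."""

    def __init__(self, base: torch.nn.Linear):
        super().__init__()
        self.base = base

    def forward(
        self, x: torch.Tensor, adapter_data: Optional[dict] = None
    ) -> torch.Tensor:
        result = self.base(x)
        # Mirrors the guard added in the diff: with no adapters loaded,
        # return the dense result untouched instead of indexing into None.
        if adapter_data is None:
            return result
        lora_a, lora_b = adapter_data["lora_a"], adapter_data["lora_b"]
        return result + (x @ lora_a) @ lora_b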

@@ -108,7 +108,17 @@ class FlashLlama(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -139,7 +149,7 @@ class FlashLlama(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property
@@ -151,7 +161,7 @@ class FlashLlama(FlashCausalLM):
         return ["q_proj", "v_proj"]
 
     def get_num_layers_for_type(self, layer_type: str) -> int:
-        return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
+        return 1 if layer_type == "lm_head" else len(self.model.model.layers)
 
     def is_row_parallel(self, layer_type: str) -> bool:
         return layer_type in ROW_PARALLEL
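The FlashLlama hunks above probe for a nested text backbone so that multimodal wrappers exposing language_model (e.g. LlavaNext, Idefics2) or text_model present the same model.layers path as a plain causal LM; the same probing is applied to BaseFlashMistral below. A small self-contained sketch of that unwrapping, using made-up namespace objects in place of real model classes:

from types import SimpleNamespace


def unwrap_language_model(model):
    """Return the text backbone of a possibly multimodal model.

    Mirrors the hasattr() probing added in the diff; the attribute names
    ("language_model", "text_model") are the ones the commit checks for.
    """
    if hasattr(model, "language_model"):
        return model.language_model
    if hasattr(model, "text_model"):
        return model.text_model
    return model


# Illustrative stand-ins for a VLM wrapper and a plain causal LM.
backbone = SimpleNamespace(model=SimpleNamespace(layers=["layer0", "layer1"]))
vlm = SimpleNamespace(language_model=backbone, vision_tower="...")
plain = backbone

assert unwrap_language_model(vlm) is backbone
assert unwrap_language_model(plain) is backbone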

@@ -119,7 +119,17 @@ class BaseFlashMistral(FlashCausalLM):
         layer_weights = {}
 
         prefix = "model.layers"
-        for i, layer in enumerate(self.model.model.layers):
+
+        # This accounts for VLMs (e.g. LlavaNext, Idefics2)
+        # that have a language_model inside of the larger model.
+        if hasattr(self.model, "language_model"):
+            _model = self.model.language_model
+        elif hasattr(self.model, "text_model"):
+            _model = self.model.text_model
+        else:
+            _model = self.model
+
+        for i, layer in enumerate(_model.model.layers):
             layer_weights[(i, "q_proj")] = (
                 f"{prefix}.{i}.self_attn.q_proj",
                 layer.self_attn.query_key_value,
@@ -150,7 +160,7 @@ class BaseFlashMistral(FlashCausalLM):
                 layer.mlp.down_proj,
             )
 
-        layer_weights[(0, "lm_head")] = ("lm_head", self.model.lm_head)
+        layer_weights[(0, "lm_head")] = ("lm_head", _model.lm_head)
         return layer_weights
 
     @property

@@ -453,6 +453,7 @@ class Mamba(Model):
         model = MambaModel(config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(Mamba, self).__init__(
+            model_id=model_id,
             model=model,
             tokenizer=tokenizer,
             requires_padding=True,
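The Mamba hunk simply forwards the model identifier to the base class constructor. A minimal sketch of that pass-through pattern, with an illustrative BaseModel standing in for the repository's Model class and its fuller argument list:

class BaseModel:
    """Illustrative stand-in: the real base class takes more arguments."""

    def __init__(self, model_id: str, requires_padding: bool):
        self.model_id = model_id
        self.requires_padding = requires_padding


class TinyMamba(BaseModel):
    def __init__(self, model_id: str):
        # Mirrors the diff: the subclass now forwards model_id to the base
        # class alongside the arguments it already passed.
        super().__init__(model_id=model_id, requires_padding=True)


assert TinyMamba("state-spaces/mamba-130m").model_id == "state-spaces/mamba-130m"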

@@ -219,7 +219,9 @@ class VlmCausalLM(BaseFlashMistral):
         return VlmCausalLMBatch
 
     def forward(
-        self, batch: VlmCausalLMBatch
+        self,
+        batch: VlmCausalLMBatch,
+        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         # Model Forward
         if batch.speculative_ids is not None:
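The widened forward signature lets callers pass per-batch adapter tensors while keeping adapter_data optional, so existing call sites that supply nothing keep working. A hedged sketch of both call paths, with a stand-in class rather than the actual VlmCausalLM:

from typing import Dict, Optional, Tuple

import torch


class FakeVlmModel:
    """Stand-in with the same forward signature the diff introduces."""

    def forward(
        self,
        batch: torch.Tensor,
        adapter_data: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = batch * 2.0  # placeholder for the real model forward
        if adapter_data is not None:
            logits = logits + adapter_data["bias"]  # placeholder adapter effect
        return logits, None


model = FakeVlmModel()
batch = torch.ones(2, 4)
model.forward(batch)  # no adapters: adapter_data defaults to None
model.forward(batch, adapter_data={"bias": torch.zeros(2, 4)})  # with adapter tensors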