fix: update all models forwards to include adapter_data
This commit is contained in:
parent
1deb372564
commit
101b95adc4
|
@ -514,6 +514,7 @@ class FlashCohereForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
|
|
@ -724,6 +724,7 @@ class FlashDbrxForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
hidden_states = self.model(
|
||||
input_ids,
|
||||
|
|
|
@ -460,6 +460,7 @@ class FlashGemmaForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
input_embeds = self.embed_tokens(input_ids)
|
||||
hidden_states = self.model(
|
||||
|
|
|
@ -445,6 +445,7 @@ class FlashGPT2ForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
token_embeds = self.embed_tokens(input_ids)
|
||||
position_embeds = self.embed_positions(position_ids)
|
||||
|
|
|
@ -638,6 +638,7 @@ class FlashMixtralForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
|
|
@ -74,6 +74,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
|
|||
# Unused here
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
# TODO This is odd but apparently pali gemma position ids start at 1.
|
||||
|
|
|
@ -378,6 +378,7 @@ class Qwen2ForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor] = None,
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
|
|
@ -741,6 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
|
|
|
@ -178,6 +178,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None and len(pixel_values) > 0:
|
||||
|
|
Loading…
Reference in New Issue