diff --git a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
index e08a2aad..d044b492 100644
--- a/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_pali_gemma_modeling.py
@@ -19,6 +19,7 @@ from torch import nn
 from typing import Optional, List, Tuple
 
 from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
         )
 
diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py
index 7e4deaf8..a829c374 100644
--- a/server/text_generation_server/models/custom_modeling/idefics2.py
+++ b/server/text_generation_server/models/custom_modeling/idefics2.py
@@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
 )
+from text_generation_server.layers.attention import Seqlen
 from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
 
 from text_generation_server.layers import (
@@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py
index 29f5b9c7..32e9d334 100644
--- a/server/text_generation_server/models/custom_modeling/llava_next.py
+++ b/server/text_generation_server/models/custom_modeling/llava_next.py
@@ -23,6 +23,7 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.image_processing_utils import select_best_resolution
 
+from text_generation_server.layers.attention import Seqlen
 from text_generation_server.models.custom_modeling.vlm import (
     load_text_model,
     load_vision_model,
@@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
         kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
         block_tables: torch.Tensor,
         slots: torch.Tensor,
-        input_lengths: torch.Tensor,
+        seqlen: Seqlen,
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
         lm_head_indices: Optional[torch.Tensor] = None,
@@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
             kv_cache=kv_cache,
             block_tables=block_tables,
             slots=slots,
-            input_lengths=input_lengths,
+            seqlen=seqlen,
             max_s=max_s,
             true_max_s=max_s,
             prefill_cache_indices=None,
diff --git a/server/text_generation_server/models/vlm_causal_lm.py b/server/text_generation_server/models/vlm_causal_lm.py
index 2ed1a119..d6cb36fa 100644
--- a/server/text_generation_server/models/vlm_causal_lm.py
+++ b/server/text_generation_server/models/vlm_causal_lm.py
@@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM):
                 prefix_lens=batch.prefix_lens,
                 prefix_lens_tensor=prefix_lens_tensor,
             ):
-                input_lengths = Seqlen(input_lengths=input_lengths)
+                max_k = (input_lengths + prefix_lens_tensor).max().item()
+                seqlen = Seqlen(
+                    input_lengths=input_lengths,
+                    prefix_lengths=prefix_lens_tensor,
+                    cu_seqlen_q=cu_seqlen_prefill,
+                    max_q=max_s,
+                    max_k=max_k,
+                )
                 logits, speculative_logits = self.model.forward(
                     input_ids=input_ids,
                     position_ids=position_ids,
@@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM):
                 kv_cache=kv_cache,
                 block_tables=block_tables,
                 slots=slots,
-                input_lengths=input_lengths,
+                seqlen=seqlen,
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
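Note on the vlm_causal_lm.py hunk: the VLM forward pass now packages per-request length metadata into a Seqlen object instead of passing a bare input_lengths tensor, and the maximum key length accounts for prefix-cached tokens. The sketch below illustrates that construction; SeqlenSketch is a hypothetical stand-in for text_generation_server.layers.attention.Seqlen (only the keyword names visible in the diff are assumed), and the example tensor values are invented.

# Illustrative sketch only: SeqlenSketch mimics the keyword arguments used in
# the diff; it is not the library's Seqlen class.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SeqlenSketch:
    input_lengths: torch.Tensor           # new tokens per request in this step
    prefix_lengths: torch.Tensor          # tokens already present in the KV cache
    cu_seqlen_q: Optional[torch.Tensor]   # cumulative query lengths (prefill only)
    max_q: int                            # longest query length in the batch
    max_k: int                            # longest key length (new + cached)


# Example batch: two requests, the second with a cached prefix.
input_lengths = torch.tensor([5, 3])
prefix_lens_tensor = torch.tensor([0, 7])
cu_seqlen_prefill = torch.tensor([0, 5, 8], dtype=torch.int32)
max_s = int(input_lengths.max().item())

# Keys span both the cached prefix and the new tokens, so the maximum key
# length is taken over input_lengths + prefix_lengths, as in the diff.
max_k = int((input_lengths + prefix_lens_tensor).max().item())

seqlen = SeqlenSketch(
    input_lengths=input_lengths,
    prefix_lengths=prefix_lens_tensor,
    cu_seqlen_q=cu_seqlen_prefill,
    max_q=max_s,
    max_k=max_k,
)
assert seqlen.max_k == 10  # 3 new tokens + 7 cached prefix tokens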