Fixing seqlen with the new vlms.

This commit is contained in:
Nicolas Patry 2024-08-27 18:16:35 +02:00
parent 9dacac3b15
commit e0069a3a26
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
4 changed files with 18 additions and 8 deletions

View File

@ -19,6 +19,7 @@ from torch import nn
from typing import Optional, List, Tuple
from text_generation_server.layers.tensor_parallel import TensorParallelColumnLinear
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@ -70,7 +71,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
@ -107,7 +108,7 @@ class PaliGemmaForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
)

View File

@ -25,6 +25,7 @@ from transformers.activations import ACT2FN
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
)
from text_generation_server.layers.attention import Seqlen
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from text_generation_server.layers import (
@ -740,7 +741,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -826,7 +827,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,

View File

@ -23,6 +23,7 @@ from torch import nn
from transformers.activations import ACT2FN
from transformers.image_processing_utils import select_best_resolution
from text_generation_server.layers.attention import Seqlen
from text_generation_server.models.custom_modeling.vlm import (
load_text_model,
load_vision_model,
@ -170,7 +171,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
block_tables: torch.Tensor,
slots: torch.Tensor,
input_lengths: torch.Tensor,
seqlen: Seqlen,
max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None,
@ -276,7 +277,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
true_max_s=max_s,
prefill_cache_indices=None,

View File

@ -372,7 +372,14 @@ class VlmCausalLM(FlashCausalLM):
prefix_lens=batch.prefix_lens,
prefix_lens_tensor=prefix_lens_tensor,
):
input_lengths = Seqlen(input_lengths=input_lengths)
max_k = (input_lengths + prefix_lens_tensor).max().item()
seqlen = Seqlen(
input_lengths=input_lengths,
prefix_lengths=prefix_lens_tensor,
cu_seqlen_q=cu_seqlen_prefill,
max_q=max_s,
max_k=max_k,
)
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
@ -380,7 +387,7 @@ class VlmCausalLM(FlashCausalLM):
kv_cache=kv_cache,
block_tables=block_tables,
slots=slots,
input_lengths=input_lengths,
seqlen=seqlen,
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,