fix crash in multi-modal (#2245)

* fix crash in multi-modal

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* update according to review comment

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix llava_next regression in latest main

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
Wang, Yi 2024-07-24 16:39:08 +08:00 committed by GitHub
parent a895029424
commit 5ad39dd3c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 5 additions and 1 deletions

View File

@ -424,7 +424,7 @@ class FlashLlamaModel(torch.nn.Module):
FlashLlamaLayer( FlashLlamaLayer(
index=0, index=0,
prefix=( prefix=(
"model.layers.0" if not prefix else "{prefix}.model.layers.0" "model.layers.0" if not prefix else f"{prefix}.model.layers.0"
), ),
config=config, config=config,
weights=weights, weights=weights,

View File

@ -832,6 +832,7 @@ class Idefics2ForConditionalGeneration(nn.Module):
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -280,6 +280,7 @@ class LlavaNextForConditionalGeneration(nn.Module):
max_s=max_s, max_s=max_s,
true_max_s=max_s, true_max_s=max_s,
prefill_cache_indices=None, prefill_cache_indices=None,
adapter_data=adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -14,6 +14,7 @@ from text_generation_server.models.flash_causal_lm import (
) )
from text_generation_server.utils.log import log_master from text_generation_server.utils.log import log_master
from transformers import AutoProcessor from transformers import AutoProcessor
from text_generation_server.layers.attention import Seqlen
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -348,6 +349,7 @@ class VlmCausalLM(FlashCausalLM):
else: else:
cuda_graph = None cuda_graph = None
if cu_seqlen_prefill is not None or cuda_graph is None: if cu_seqlen_prefill is not None or cuda_graph is None:
input_lengths = Seqlen(input_lengths=input_lengths)
logits, speculative_logits = self.model.forward( logits, speculative_logits = self.model.forward(
input_ids=input_ids, input_ids=input_ids,
position_ids=position_ids, position_ids=position_ids,