Working state ? (Broke idefics1 temporarily).

This commit is contained in:
Nicolas Patry 2024-09-23 15:25:26 +02:00
parent 39d2073e93
commit 8abdd08ef4
No known key found for this signature in database
GPG Key ID: 64AF4752B2967863
9 changed files with 939 additions and 300 deletions

View File

@ -28,11 +28,17 @@ class ToolCall(BaseModel):
function: dict function: dict
class Chunk(BaseModel):
type: str
text: Optional[str] = None
image_url: Any = None
class Message(BaseModel): class Message(BaseModel):
# Role of the message sender # Role of the message sender
role: str role: str
# Content of the message # Content of the message
content: Optional[str] = None content: Optional[Union[str, List[Chunk]]] = None
# Optional name of the message sender # Optional name of the message sender
name: Optional[str] = None name: Optional[str] = None
# Tool calls associated with the chat completion # Tool calls associated with the chat completion

View File

@ -0,0 +1,106 @@
[
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727097740,
"id": "",
"model": "s0409/model-3",
"object": "chat.completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 24,
"total_tokens": 44
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727097740,
"id": "",
"model": "s0409/model-3",
"object": "chat.completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 24,
"total_tokens": 44
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727097740,
"id": "",
"model": "s0409/model-3",
"object": "chat.completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 24,
"total_tokens": 44
}
},
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727097740,
"id": "",
"model": "s0409/model-3",
"object": "chat.completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 24,
"total_tokens": 44
}
}
]

View File

@ -0,0 +1,26 @@
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"message": {
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
"name": null,
"role": "assistant",
"tool_calls": null
},
"usage": null
}
],
"created": 1727090615,
"id": "",
"model": "s0409/model-3",
"object": "chat.completion",
"system_fingerprint": "2.2.1-dev0-native",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 24,
"total_tokens": 44
}
}

View File

@ -0,0 +1,108 @@
import pytest
import base64
import asyncio
@pytest.fixture(scope="module")
def mllama_handle(launcher):
with launcher("s0409/model-3", num_shard=2) as handle:
yield handle
@pytest.fixture(scope="module")
async def mllama(mllama_handle):
await mllama_handle.health(300)
return mllama_handle.client
# TODO fix the server parsser to count inline image tokens correctly
def get_chicken():
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
def get_cow_beach():
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
@pytest.mark.asyncio
async def test_mllama_simpl(mllama, response_snapshot):
# chicken = get_chicken()
response = await mllama.chat(
max_tokens=20,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
assert response.usage == {
"completion_tokens": 20,
"prompt_tokens": 24,
"total_tokens": 44,
}
assert (
response.choices[0].message.content
== "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
)
assert response == response_snapshot
@pytest.mark.release
@pytest.mark.asyncio
async def test_mllama_load(mllama, generate_load, response_snapshot):
futures = [
mllama.chat(
max_tokens=20,
temperature=0.0,
messages=[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Can you tell me a very short story based on the image?",
},
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
},
},
],
},
],
)
for i in range(4)
]
responses = await asyncio.gather(*futures)
generated_texts = [response.choices[0].message.content for response in responses]
assert (
generated_texts[0]
== "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
)
assert len(generated_texts) == 4
assert generated_texts, all(
[text == generated_texts[0] for text in generated_texts]
)
assert responses == response_snapshot

View File

@ -146,6 +146,7 @@ pub enum Config {
ClipVisionModel(ClipVisionModel), ClipVisionModel(ClipVisionModel),
Mistral, Mistral,
Idefics, Idefics,
Mllama,
Idefics2(Idefics2), Idefics2(Idefics2),
Ssm, Ssm,
GptBigcode, GptBigcode,

View File

@ -567,6 +567,7 @@ fn image_tokens(
use HubPreprocessorConfig::*; use HubPreprocessorConfig::*;
match config { match config {
Idefics => "<image>".to_string(), Idefics => "<image>".to_string(),
Mllama => "<|image|>".to_string(),
Idefics2(config) => { Idefics2(config) => {
const FAKE: &str = "<fake_token_around_image>"; const FAKE: &str = "<fake_token_around_image>";
const IMAGE: &str = "<image>"; const IMAGE: &str = "<image>";
@ -618,7 +619,7 @@ fn prepare_input(
use Config::*; use Config::*;
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap()); static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
let (tokenizer_query, input_chunks) = match config { let (tokenizer_query, input_chunks) = match config {
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => { Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
let mut input_chunks = Vec::new(); let mut input_chunks = Vec::new();
let mut tokenizer_query = String::with_capacity(inputs.len()); let mut tokenizer_query = String::with_capacity(inputs.len());
let mut start = 0; let mut start = 0;

View File

@ -98,6 +98,8 @@ class IDEFICSSharded(IdeficsCausalLM):
else: else:
raise RuntimeError(f"Unsupported model type {config.model_type}") raise RuntimeError(f"Unsupported model type {config.model_type}")
self.config = config
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(IdeficsCausalLM, self).__init__( super(IdeficsCausalLM, self).__init__(
model_id=model_id, model_id=model_id,

View File

@ -6,8 +6,6 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from opentelemetry import trace from opentelemetry import trace
from transformers import ( from transformers import (
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
ProcessorMixin, ProcessorMixin,
) )
@ -38,6 +36,9 @@ class IdeficsCausalLMBatch(Batch):
attention_mask: torch.Tensor attention_mask: torch.Tensor
position_ids: torch.Tensor position_ids: torch.Tensor
pixel_values: Optional[torch.Tensor] pixel_values: Optional[torch.Tensor]
aspect_ratio_ids: Optional[torch.Tensor]
aspect_ratio_mask: Optional[torch.Tensor]
cross_attention_mask: Optional[torch.Tensor]
image_hidden_states: Optional[torch.Tensor] image_hidden_states: Optional[torch.Tensor]
image_attention_mask: Optional[torch.Tensor] image_attention_mask: Optional[torch.Tensor]
past_key_values: Optional[List[Tuple]] past_key_values: Optional[List[Tuple]]
@ -164,7 +165,7 @@ class IdeficsCausalLMBatch(Batch):
image = Image.open(BytesIO(chunk.image.data)) image = Image.open(BytesIO(chunk.image.data))
curr_images.append(image) curr_images.append(image)
# TODO unsure about BOS # TODO unsure about BOS
curr_text += "<|image|><|begin_of_text|>" curr_text += "<|image|>"
else: else:
raise RuntimeError(f"Invalid chunk type {chunk_type}") raise RuntimeError(f"Invalid chunk type {chunk_type}")
images.append(curr_images) images.append(curr_images)
@ -173,6 +174,8 @@ class IdeficsCausalLMBatch(Batch):
# The processor replaces the call to tokenizer, and # The processor replaces the call to tokenizer, and
# a/ takes care of fetching images from the URL # a/ takes care of fetching images from the URL
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model # b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
if all(len(im) == 0 for im in images):
images = None
tokenized_inputs = processor( tokenized_inputs = processor(
images=images, images=images,
text=texts, text=texts,
@ -205,7 +208,10 @@ class IdeficsCausalLMBatch(Batch):
# Do the same for image_attention_mask # Do the same for image_attention_mask
if pixel_values is None: if pixel_values is None:
image_attention_mask = None image_attention_mask = None
else: aspect_ratio_ids = None
aspect_ratio_mask = None
cross_attention_mask = None
elif "image_attention_mask" in tokenized_inputs:
image_attention_mask = input_ids.new_zeros( image_attention_mask = input_ids.new_zeros(
( (
pb.size, pb.size,
@ -216,6 +222,19 @@ class IdeficsCausalLMBatch(Batch):
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
"image_attention_mask" "image_attention_mask"
] ]
aspect_ratio_ids = None
aspect_ratio_mask = None
cross_attention_mask = None
else:
image_attention_mask = None
aspect_ratio_ids = tokenized_inputs["aspect_ratio_ids"]
aspect_ratio_mask = tokenized_inputs["aspect_ratio_mask"]
cross_attention_mask = tokenized_inputs["cross_attention_mask"]
pixel_values = pixel_values.to(dtype=dtype)
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
tokenized_inputs["input_ids"] = tokenized_inputs["input_ids"].clamp(
max=processor.tokenizer.vocab_size - 1
)
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
@ -245,6 +264,9 @@ class IdeficsCausalLMBatch(Batch):
max_input_length=max_input_length.item(), max_input_length=max_input_length.item(),
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
max_tokens=max_tokens, max_tokens=max_tokens,
aspect_ratio_ids=aspect_ratio_ids,
aspect_ratio_mask=aspect_ratio_mask,
cross_attention_mask=cross_attention_mask,
) )
@tracer.start_as_current_span("filter") @tracer.start_as_current_span("filter")
@ -308,15 +330,21 @@ class IdeficsCausalLMBatch(Batch):
+ new_padding_right_offset, + new_padding_right_offset,
] ]
# Do the same for pixel_values and image_attention_mask # Do the same for pixel_values and image_attention_mask
pixel_values = self.pixel_values[keep_indices] if self.pixel_values is not None:
self.image_attention_mask = self.image_attention_mask[ pixel_values = self.pixel_values[keep_indices]
keep_indices, else:
-(self.padding_right_offset + max_input_length) : ( pixel_values = None
self.image_attention_mask.shape[1] - self.padding_right_offset
) if self.image_attention_mask is not None:
+ new_padding_right_offset, self.image_attention_mask = self.image_attention_mask[
:, keep_indices,
] -(self.padding_right_offset + max_input_length) : (
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:,
]
if self.image_hidden_states is None: if self.image_hidden_states is None:
image_hidden_states = None image_hidden_states = None
else: else:
@ -359,6 +387,9 @@ class IdeficsCausalLMBatch(Batch):
self.max_input_length = max_input_length self.max_input_length = max_input_length
self.padding_right_offset = new_padding_right_offset self.padding_right_offset = new_padding_right_offset
self.max_tokens = max_tokens self.max_tokens = max_tokens
self.aspect_ratio_ids = None
self.aspect_ratio_mask = None
self.cross_attention_mask = None
return self return self
@ -376,7 +407,8 @@ class IdeficsCausalLMBatch(Batch):
for batch in batches: for batch in batches:
total_batch_size += len(batch) total_batch_size += len(batch)
max_input_length = max(max_input_length, batch.max_input_length) max_input_length = max(max_input_length, batch.max_input_length)
max_num_images = max(max_num_images, batch.pixel_values.size(1)) if batch.pixel_values is not None:
max_num_images = max(max_num_images, batch.pixel_values.size(1))
padding_right_offset = max(padding_right_offset, batch.padding_right_offset) padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
# Batch attributes # Batch attributes
@ -439,16 +471,19 @@ class IdeficsCausalLMBatch(Batch):
(total_batch_size, max_input_length + padding_right_offset), (total_batch_size, max_input_length + padding_right_offset),
) )
curr_batch_max_num_images = batch.pixel_values.size(1) if batch.pixel_values is not None:
if pixel_values is None: curr_batch_max_num_images = batch.pixel_values.size(1)
pixel_values = batch.pixel_values.new_zeros( if pixel_values is None:
(total_batch_size, max_num_images, 3, 224, 224) pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
batch.pixel_values
) )
pixel_values[start_index:end_index, :curr_batch_max_num_images] = ( else:
batch.pixel_values pixel_values = None
)
if image_attention_mask is None: if image_attention_mask is None and batch.image_attention_mask is not None:
image_attention_mask = batch.image_attention_mask.new_zeros( image_attention_mask = batch.image_attention_mask.new_zeros(
( (
total_batch_size, total_batch_size,
@ -472,13 +507,14 @@ class IdeficsCausalLMBatch(Batch):
:, :,
batch_left_offset : -batch.padding_right_offset, batch_left_offset : -batch.padding_right_offset,
] ]
image_attention_mask[ if batch.image_attention_mask is not None:
start_index:end_index, image_attention_mask[
left_offset:-padding_right_offset, start_index:end_index,
:curr_batch_max_num_images, left_offset:-padding_right_offset,
] = batch.image_attention_mask[ :curr_batch_max_num_images,
:, batch_left_offset : -batch.padding_right_offset, : ] = batch.image_attention_mask[
] :, batch_left_offset : -batch.padding_right_offset, :
]
# Create empty tensor # Create empty tensor
# position_ids is always of shape [batch_size, 1] # position_ids is always of shape [batch_size, 1]
@ -531,7 +567,20 @@ class IdeficsCausalLMBatch(Batch):
# Iterate over attention layers # Iterate over attention layers
# Concatenate past key values layer by layer to allow incremental garbage collection # Concatenate past key values layer by layer to allow incremental garbage collection
for j in range(len(first_past_kvs)): for j in range(len(first_past_kvs)):
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape) _, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
if seqlen > max_input_length:
# XXX: This is probably a cross attention key value
# If not this is ok
_padded_past_keys_shape = (
total_batch_size,
_num_heads,
seqlen,
_head_dim,
)
else:
_padded_past_keys_shape = padded_past_keys_shape
padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
start_index = 0 start_index = 0
for batch in batches: for batch in batches:
past_keys = batch.past_key_values[j][0] past_keys = batch.past_key_values[j][0]
@ -542,6 +591,9 @@ class IdeficsCausalLMBatch(Batch):
end_index = start_index + len(batch) end_index = start_index + len(batch)
# We slice the keys to remove the padding from previous batches # We slice the keys to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1 past_seq_len = batch.max_input_length - 1
if past_keys.shape[2] > past_seq_len:
# XXX: This is a cross attention kv in mllama
past_seq_len = past_keys.shape[2]
if batch.keys_head_dim_last: if batch.keys_head_dim_last:
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = ( padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
past_keys[:, :, -past_seq_len:, :] past_keys[:, :, -past_seq_len:, :]
@ -555,8 +607,20 @@ class IdeficsCausalLMBatch(Batch):
start_index = end_index start_index = end_index
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
if seqlen > max_input_length:
# XXX: This is probably a cross attention key value
# If not this is ok
_padded_past_values_shape = (
total_batch_size,
_num_heads,
seqlen,
_head_dim,
)
else:
_padded_past_values_shape = padded_past_values_shape
padded_past_values = first_past_kvs[j][1].new_zeros( padded_past_values = first_past_kvs[j][1].new_zeros(
padded_past_values_shape _padded_past_values_shape
) )
start_index = 0 start_index = 0
for batch in batches: for batch in batches:
@ -568,6 +632,9 @@ class IdeficsCausalLMBatch(Batch):
end_index = start_index + len(batch) end_index = start_index + len(batch)
# We slice the past values to remove the padding from previous batches # We slice the past values to remove the padding from previous batches
past_seq_len = batch.max_input_length - 1 past_seq_len = batch.max_input_length - 1
if past_values.shape[2] > past_seq_len:
# XXX: This is a cross attention kv in mllama
past_seq_len = past_values.shape[2]
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = ( padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
past_values[:, :, -past_seq_len:, :] past_values[:, :, -past_seq_len:, :]
) )
@ -599,6 +666,10 @@ class IdeficsCausalLMBatch(Batch):
padding_right_offset=padding_right_offset, padding_right_offset=padding_right_offset,
keys_head_dim_last=batches[0].keys_head_dim_last, keys_head_dim_last=batches[0].keys_head_dim_last,
max_tokens=max_tokens, max_tokens=max_tokens,
# No need to keep this around. for Mllamma
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
) )
def __len__(self): def __len__(self):
@ -606,77 +677,6 @@ class IdeficsCausalLMBatch(Batch):
class IdeficsCausalLM(Model): class IdeficsCausalLM(Model):
def __init__(
self,
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
self.quantize = quantize
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if dtype is None else dtype
else:
if quantize:
raise ValueError("quantization is not available on CPU")
device = torch.device("cpu")
dtype = torch.float32 if dtype is None else dtype
tokenizer = AutoTokenizer.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
self.processor = AutoProcessor.from_pretrained(
model_id,
revision=revision,
padding_side="left",
truncation_side="left",
trust_remote_code=trust_remote_code,
)
model = IdeficsForVisionText2Text.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map=(
"auto"
if torch.cuda.is_available() and torch.cuda.device_count() > 1
else None
),
load_in_8bit=quantize == "bitsandbytes",
trust_remote_code=trust_remote_code,
)
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
model = model.cuda()
if tokenizer.pad_token_id is None:
if model.config.pad_token_id is not None:
tokenizer.pad_token_id = model.config.pad_token_id
elif model.config.eos_token_id is not None:
tokenizer.pad_token_id = model.config.eos_token_id
elif tokenizer.eos_token_id is not None:
tokenizer.pad_token_id = tokenizer.eos_token_id
else:
tokenizer.add_special_tokens({"pad_token": "<unk>"})
super(IdeficsCausalLM, self).__init__(
model_id=model_id,
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
)
@property @property
def batch_type(self) -> Type[IdeficsCausalLMBatch]: def batch_type(self) -> Type[IdeficsCausalLMBatch]:
return IdeficsCausalLMBatch return IdeficsCausalLMBatch
@ -690,6 +690,9 @@ class IdeficsCausalLM(Model):
image_hidden_states, image_hidden_states,
image_attention_mask, image_attention_mask,
past_key_values: Optional = None, past_key_values: Optional = None,
aspect_ratio_ids=None,
aspect_ratio_mask=None,
cross_attention_mask=None,
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
# Model Forward # Model Forward
kwargs = { kwargs = {
@ -699,18 +702,23 @@ class IdeficsCausalLM(Model):
"image_hidden_states": image_hidden_states, "image_hidden_states": image_hidden_states,
"image_attention_mask": image_attention_mask, "image_attention_mask": image_attention_mask,
"past_key_values": past_key_values, "past_key_values": past_key_values,
"use_cache": True,
"return_dict": True,
} }
if self.has_position_ids: if self.has_position_ids:
kwargs["position_ids"] = position_ids kwargs["position_ids"] = position_ids
if aspect_ratio_ids is not None:
kwargs["aspect_ratio_ids"] = aspect_ratio_ids
if aspect_ratio_mask is not None:
kwargs["aspect_ratio_mask"] = aspect_ratio_mask
if cross_attention_mask is not None:
kwargs["cross_attention_mask"] = cross_attention_mask
outputs, speculative_logits = self.model.forward(**kwargs) outputs, speculative_logits = self.model.forward(**kwargs)
assert outputs.past_key_values is not None
return ( return (
outputs.logits, outputs.logits,
speculative_logits, speculative_logits,
outputs.past_key_values, outputs.past_key_values,
outputs.image_hidden_states, getattr(outputs, "image_hidden_states", None),
) )
@tracer.start_as_current_span("generate_token") @tracer.start_as_current_span("generate_token")
@ -745,9 +753,13 @@ class IdeficsCausalLM(Model):
image_hidden_states=batch.image_hidden_states, image_hidden_states=batch.image_hidden_states,
image_attention_mask=image_attention_mask, image_attention_mask=image_attention_mask,
past_key_values=batch.past_key_values, past_key_values=batch.past_key_values,
aspect_ratio_ids=batch.aspect_ratio_ids,
aspect_ratio_mask=batch.aspect_ratio_mask,
cross_attention_mask=batch.cross_attention_mask,
) )
# Hardcoded remove image tokens # Hardcoded remove image tokens
logits[:, 32000:32001] = torch.finfo(logits.dtype).min if self.config.model_type == "idefics":
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
start_decode = time.time_ns() start_decode = time.time_ns()
@ -890,10 +902,13 @@ class IdeficsCausalLM(Model):
batch.input_ids = batch.input_ids[:, :1] batch.input_ids = batch.input_ids[:, :1]
# Update attention_mask as we added a new token to input_ids # Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1 # batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[:, -batch.padding_right_offset, :] = ( if batch.image_attention_mask is not None:
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
) batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
)
if batch.cross_attention_mask is not None:
batch.cross_attention_mask = batch.cross_attention_mask[:, -1:]
# Decrease right offset # Decrease right offset
batch.padding_right_offset -= 1 batch.padding_right_offset -= 1
@ -903,7 +918,8 @@ class IdeficsCausalLM(Model):
# Update past key values # Update past key values
batch.past_key_values = past batch.past_key_values = past
batch.image_hidden_states = image_hidden_states batch.image_hidden_states = image_hidden_states
if self.model.config.model_type == "mllama":
batch.pixel_values = None
forward_ns = start_decode - start forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode decode_ns = time.time_ns() - start_decode
return generations, batch, (forward_ns, decode_ns) return generations, batch, (forward_ns, decode_ns)