Working state ? (Broke idefics1 temporarily).
This commit is contained in:
parent
39d2073e93
commit
8abdd08ef4
|
@ -28,11 +28,17 @@ class ToolCall(BaseModel):
|
|||
function: dict
|
||||
|
||||
|
||||
class Chunk(BaseModel):
|
||||
type: str
|
||||
text: Optional[str] = None
|
||||
image_url: Any = None
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
# Role of the message sender
|
||||
role: str
|
||||
# Content of the message
|
||||
content: Optional[str] = None
|
||||
content: Optional[Union[str, List[Chunk]]] = None
|
||||
# Optional name of the message sender
|
||||
name: Optional[str] = None
|
||||
# Tool calls associated with the chat completion
|
||||
|
|
|
@ -0,0 +1,106 @@
|
|||
[
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727097740,
|
||||
"id": "",
|
||||
"model": "s0409/model-3",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 24,
|
||||
"total_tokens": 44
|
||||
}
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727097740,
|
||||
"id": "",
|
||||
"model": "s0409/model-3",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 24,
|
||||
"total_tokens": 44
|
||||
}
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727097740,
|
||||
"id": "",
|
||||
"model": "s0409/model-3",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 24,
|
||||
"total_tokens": 44
|
||||
}
|
||||
},
|
||||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727097740,
|
||||
"id": "",
|
||||
"model": "s0409/model-3",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 24,
|
||||
"total_tokens": 44
|
||||
}
|
||||
}
|
||||
]
|
|
@ -0,0 +1,26 @@
|
|||
{
|
||||
"choices": [
|
||||
{
|
||||
"finish_reason": "length",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"message": {
|
||||
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||
"name": null,
|
||||
"role": "assistant",
|
||||
"tool_calls": null
|
||||
},
|
||||
"usage": null
|
||||
}
|
||||
],
|
||||
"created": 1727090615,
|
||||
"id": "",
|
||||
"model": "s0409/model-3",
|
||||
"object": "chat.completion",
|
||||
"system_fingerprint": "2.2.1-dev0-native",
|
||||
"usage": {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 24,
|
||||
"total_tokens": 44
|
||||
}
|
||||
}
|
|
@ -0,0 +1,108 @@
|
|||
import pytest
|
||||
import base64
|
||||
import asyncio
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def mllama_handle(launcher):
|
||||
with launcher("s0409/model-3", num_shard=2) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def mllama(mllama_handle):
|
||||
await mllama_handle.health(300)
|
||||
return mllama_handle.client
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
def get_cow_beach():
|
||||
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mllama_simpl(mllama, response_snapshot):
|
||||
# chicken = get_chicken()
|
||||
response = await mllama.chat(
|
||||
max_tokens=20,
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Can you tell me a very short story based on the image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert response.usage == {
|
||||
"completion_tokens": 20,
|
||||
"prompt_tokens": 24,
|
||||
"total_tokens": 44,
|
||||
}
|
||||
assert (
|
||||
response.choices[0].message.content
|
||||
== "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
|
||||
)
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.release
|
||||
@pytest.mark.asyncio
|
||||
async def test_mllama_load(mllama, generate_load, response_snapshot):
|
||||
futures = [
|
||||
mllama.chat(
|
||||
max_tokens=20,
|
||||
temperature=0.0,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Can you tell me a very short story based on the image?",
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
)
|
||||
for i in range(4)
|
||||
]
|
||||
responses = await asyncio.gather(*futures)
|
||||
|
||||
generated_texts = [response.choices[0].message.content for response in responses]
|
||||
|
||||
assert (
|
||||
generated_texts[0]
|
||||
== "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
|
||||
)
|
||||
assert len(generated_texts) == 4
|
||||
assert generated_texts, all(
|
||||
[text == generated_texts[0] for text in generated_texts]
|
||||
)
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -146,6 +146,7 @@ pub enum Config {
|
|||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Mllama,
|
||||
Idefics2(Idefics2),
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
|
|
|
@ -567,6 +567,7 @@ fn image_tokens(
|
|||
use HubPreprocessorConfig::*;
|
||||
match config {
|
||||
Idefics => "<image>".to_string(),
|
||||
Mllama => "<|image|>".to_string(),
|
||||
Idefics2(config) => {
|
||||
const FAKE: &str = "<fake_token_around_image>";
|
||||
const IMAGE: &str = "<image>";
|
||||
|
@ -618,7 +619,7 @@ fn prepare_input(
|
|||
use Config::*;
|
||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||
let (tokenizer_query, input_chunks) = match config {
|
||||
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||
let mut input_chunks = Vec::new();
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -98,6 +98,8 @@ class IDEFICSSharded(IdeficsCausalLM):
|
|||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
self.config = config
|
||||
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(IdeficsCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
|
|
|
@ -6,8 +6,6 @@ import time
|
|||
from dataclasses import dataclass
|
||||
from opentelemetry import trace
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
)
|
||||
|
@ -38,6 +36,9 @@ class IdeficsCausalLMBatch(Batch):
|
|||
attention_mask: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
pixel_values: Optional[torch.Tensor]
|
||||
aspect_ratio_ids: Optional[torch.Tensor]
|
||||
aspect_ratio_mask: Optional[torch.Tensor]
|
||||
cross_attention_mask: Optional[torch.Tensor]
|
||||
image_hidden_states: Optional[torch.Tensor]
|
||||
image_attention_mask: Optional[torch.Tensor]
|
||||
past_key_values: Optional[List[Tuple]]
|
||||
|
@ -164,7 +165,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
image = Image.open(BytesIO(chunk.image.data))
|
||||
curr_images.append(image)
|
||||
# TODO unsure about BOS
|
||||
curr_text += "<|image|><|begin_of_text|>"
|
||||
curr_text += "<|image|>"
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||
images.append(curr_images)
|
||||
|
@ -173,6 +174,8 @@ class IdeficsCausalLMBatch(Batch):
|
|||
# The processor replaces the call to tokenizer, and
|
||||
# a/ takes care of fetching images from the URL
|
||||
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
|
||||
if all(len(im) == 0 for im in images):
|
||||
images = None
|
||||
tokenized_inputs = processor(
|
||||
images=images,
|
||||
text=texts,
|
||||
|
@ -205,7 +208,10 @@ class IdeficsCausalLMBatch(Batch):
|
|||
# Do the same for image_attention_mask
|
||||
if pixel_values is None:
|
||||
image_attention_mask = None
|
||||
else:
|
||||
aspect_ratio_ids = None
|
||||
aspect_ratio_mask = None
|
||||
cross_attention_mask = None
|
||||
elif "image_attention_mask" in tokenized_inputs:
|
||||
image_attention_mask = input_ids.new_zeros(
|
||||
(
|
||||
pb.size,
|
||||
|
@ -216,6 +222,19 @@ class IdeficsCausalLMBatch(Batch):
|
|||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||
"image_attention_mask"
|
||||
]
|
||||
aspect_ratio_ids = None
|
||||
aspect_ratio_mask = None
|
||||
cross_attention_mask = None
|
||||
else:
|
||||
image_attention_mask = None
|
||||
aspect_ratio_ids = tokenized_inputs["aspect_ratio_ids"]
|
||||
aspect_ratio_mask = tokenized_inputs["aspect_ratio_mask"]
|
||||
cross_attention_mask = tokenized_inputs["cross_attention_mask"]
|
||||
pixel_values = pixel_values.to(dtype=dtype)
|
||||
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
|
||||
tokenized_inputs["input_ids"] = tokenized_inputs["input_ids"].clamp(
|
||||
max=processor.tokenizer.vocab_size - 1
|
||||
)
|
||||
|
||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||
|
@ -245,6 +264,9 @@ class IdeficsCausalLMBatch(Batch):
|
|||
max_input_length=max_input_length.item(),
|
||||
padding_right_offset=padding_right_offset,
|
||||
max_tokens=max_tokens,
|
||||
aspect_ratio_ids=aspect_ratio_ids,
|
||||
aspect_ratio_mask=aspect_ratio_mask,
|
||||
cross_attention_mask=cross_attention_mask,
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("filter")
|
||||
|
@ -308,7 +330,12 @@ class IdeficsCausalLMBatch(Batch):
|
|||
+ new_padding_right_offset,
|
||||
]
|
||||
# Do the same for pixel_values and image_attention_mask
|
||||
if self.pixel_values is not None:
|
||||
pixel_values = self.pixel_values[keep_indices]
|
||||
else:
|
||||
pixel_values = None
|
||||
|
||||
if self.image_attention_mask is not None:
|
||||
self.image_attention_mask = self.image_attention_mask[
|
||||
keep_indices,
|
||||
-(self.padding_right_offset + max_input_length) : (
|
||||
|
@ -317,6 +344,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
+ new_padding_right_offset,
|
||||
:,
|
||||
]
|
||||
|
||||
if self.image_hidden_states is None:
|
||||
image_hidden_states = None
|
||||
else:
|
||||
|
@ -359,6 +387,9 @@ class IdeficsCausalLMBatch(Batch):
|
|||
self.max_input_length = max_input_length
|
||||
self.padding_right_offset = new_padding_right_offset
|
||||
self.max_tokens = max_tokens
|
||||
self.aspect_ratio_ids = None
|
||||
self.aspect_ratio_mask = None
|
||||
self.cross_attention_mask = None
|
||||
|
||||
return self
|
||||
|
||||
|
@ -376,6 +407,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
for batch in batches:
|
||||
total_batch_size += len(batch)
|
||||
max_input_length = max(max_input_length, batch.max_input_length)
|
||||
if batch.pixel_values is not None:
|
||||
max_num_images = max(max_num_images, batch.pixel_values.size(1))
|
||||
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||
|
||||
|
@ -439,6 +471,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
(total_batch_size, max_input_length + padding_right_offset),
|
||||
)
|
||||
|
||||
if batch.pixel_values is not None:
|
||||
curr_batch_max_num_images = batch.pixel_values.size(1)
|
||||
if pixel_values is None:
|
||||
pixel_values = batch.pixel_values.new_zeros(
|
||||
|
@ -447,8 +480,10 @@ class IdeficsCausalLMBatch(Batch):
|
|||
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
|
||||
batch.pixel_values
|
||||
)
|
||||
else:
|
||||
pixel_values = None
|
||||
|
||||
if image_attention_mask is None:
|
||||
if image_attention_mask is None and batch.image_attention_mask is not None:
|
||||
image_attention_mask = batch.image_attention_mask.new_zeros(
|
||||
(
|
||||
total_batch_size,
|
||||
|
@ -472,6 +507,7 @@ class IdeficsCausalLMBatch(Batch):
|
|||
:,
|
||||
batch_left_offset : -batch.padding_right_offset,
|
||||
]
|
||||
if batch.image_attention_mask is not None:
|
||||
image_attention_mask[
|
||||
start_index:end_index,
|
||||
left_offset:-padding_right_offset,
|
||||
|
@ -531,7 +567,20 @@ class IdeficsCausalLMBatch(Batch):
|
|||
# Iterate over attention layers
|
||||
# Concatenate past key values layer by layer to allow incremental garbage collection
|
||||
for j in range(len(first_past_kvs)):
|
||||
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
|
||||
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
|
||||
if seqlen > max_input_length:
|
||||
# XXX: This is probably a cross attention key value
|
||||
# If not this is ok
|
||||
_padded_past_keys_shape = (
|
||||
total_batch_size,
|
||||
_num_heads,
|
||||
seqlen,
|
||||
_head_dim,
|
||||
)
|
||||
else:
|
||||
_padded_past_keys_shape = padded_past_keys_shape
|
||||
|
||||
padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
past_keys = batch.past_key_values[j][0]
|
||||
|
@ -542,6 +591,9 @@ class IdeficsCausalLMBatch(Batch):
|
|||
end_index = start_index + len(batch)
|
||||
# We slice the keys to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
if past_keys.shape[2] > past_seq_len:
|
||||
# XXX: This is a cross attention kv in mllama
|
||||
past_seq_len = past_keys.shape[2]
|
||||
if batch.keys_head_dim_last:
|
||||
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
|
||||
past_keys[:, :, -past_seq_len:, :]
|
||||
|
@ -555,8 +607,20 @@ class IdeficsCausalLMBatch(Batch):
|
|||
|
||||
start_index = end_index
|
||||
|
||||
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
|
||||
if seqlen > max_input_length:
|
||||
# XXX: This is probably a cross attention key value
|
||||
# If not this is ok
|
||||
_padded_past_values_shape = (
|
||||
total_batch_size,
|
||||
_num_heads,
|
||||
seqlen,
|
||||
_head_dim,
|
||||
)
|
||||
else:
|
||||
_padded_past_values_shape = padded_past_values_shape
|
||||
padded_past_values = first_past_kvs[j][1].new_zeros(
|
||||
padded_past_values_shape
|
||||
_padded_past_values_shape
|
||||
)
|
||||
start_index = 0
|
||||
for batch in batches:
|
||||
|
@ -568,6 +632,9 @@ class IdeficsCausalLMBatch(Batch):
|
|||
end_index = start_index + len(batch)
|
||||
# We slice the past values to remove the padding from previous batches
|
||||
past_seq_len = batch.max_input_length - 1
|
||||
if past_values.shape[2] > past_seq_len:
|
||||
# XXX: This is a cross attention kv in mllama
|
||||
past_seq_len = past_values.shape[2]
|
||||
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
|
||||
past_values[:, :, -past_seq_len:, :]
|
||||
)
|
||||
|
@ -599,6 +666,10 @@ class IdeficsCausalLMBatch(Batch):
|
|||
padding_right_offset=padding_right_offset,
|
||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||
max_tokens=max_tokens,
|
||||
# No need to keep this around. for Mllamma
|
||||
aspect_ratio_ids=None,
|
||||
aspect_ratio_mask=None,
|
||||
cross_attention_mask=None,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
|
@ -606,77 +677,6 @@ class IdeficsCausalLMBatch(Batch):
|
|||
|
||||
|
||||
class IdeficsCausalLM(Model):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.quantize = quantize
|
||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
||||
IdeficsForVisionText2Text,
|
||||
)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if dtype is None else dtype
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization is not available on CPU")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32 if dtype is None else dtype
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
padding_side="left",
|
||||
truncation_side="left",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
model = IdeficsForVisionText2Text.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
device_map=(
|
||||
"auto"
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
||||
else None
|
||||
),
|
||||
load_in_8bit=quantize == "bitsandbytes",
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
||||
model = model.cuda()
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
if model.config.pad_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.pad_token_id
|
||||
elif model.config.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = model.config.eos_token_id
|
||||
elif tokenizer.eos_token_id is not None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
else:
|
||||
tokenizer.add_special_tokens({"pad_token": "<unk>"})
|
||||
|
||||
super(IdeficsCausalLM, self).__init__(
|
||||
model_id=model_id,
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
requires_padding=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
@property
|
||||
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
|
||||
return IdeficsCausalLMBatch
|
||||
|
@ -690,6 +690,9 @@ class IdeficsCausalLM(Model):
|
|||
image_hidden_states,
|
||||
image_attention_mask,
|
||||
past_key_values: Optional = None,
|
||||
aspect_ratio_ids=None,
|
||||
aspect_ratio_mask=None,
|
||||
cross_attention_mask=None,
|
||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Model Forward
|
||||
kwargs = {
|
||||
|
@ -699,18 +702,23 @@ class IdeficsCausalLM(Model):
|
|||
"image_hidden_states": image_hidden_states,
|
||||
"image_attention_mask": image_attention_mask,
|
||||
"past_key_values": past_key_values,
|
||||
"use_cache": True,
|
||||
"return_dict": True,
|
||||
}
|
||||
if self.has_position_ids:
|
||||
kwargs["position_ids"] = position_ids
|
||||
if aspect_ratio_ids is not None:
|
||||
kwargs["aspect_ratio_ids"] = aspect_ratio_ids
|
||||
if aspect_ratio_mask is not None:
|
||||
kwargs["aspect_ratio_mask"] = aspect_ratio_mask
|
||||
if cross_attention_mask is not None:
|
||||
kwargs["cross_attention_mask"] = cross_attention_mask
|
||||
|
||||
outputs, speculative_logits = self.model.forward(**kwargs)
|
||||
assert outputs.past_key_values is not None
|
||||
return (
|
||||
outputs.logits,
|
||||
speculative_logits,
|
||||
outputs.past_key_values,
|
||||
outputs.image_hidden_states,
|
||||
getattr(outputs, "image_hidden_states", None),
|
||||
)
|
||||
|
||||
@tracer.start_as_current_span("generate_token")
|
||||
|
@ -745,8 +753,12 @@ class IdeficsCausalLM(Model):
|
|||
image_hidden_states=batch.image_hidden_states,
|
||||
image_attention_mask=image_attention_mask,
|
||||
past_key_values=batch.past_key_values,
|
||||
aspect_ratio_ids=batch.aspect_ratio_ids,
|
||||
aspect_ratio_mask=batch.aspect_ratio_mask,
|
||||
cross_attention_mask=batch.cross_attention_mask,
|
||||
)
|
||||
# Hardcoded remove image tokens
|
||||
if self.config.model_type == "idefics":
|
||||
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
|
||||
|
||||
start_decode = time.time_ns()
|
||||
|
@ -890,10 +902,13 @@ class IdeficsCausalLM(Model):
|
|||
batch.input_ids = batch.input_ids[:, :1]
|
||||
|
||||
# Update attention_mask as we added a new token to input_ids
|
||||
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||
# batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||
if batch.image_attention_mask is not None:
|
||||
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
|
||||
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
|
||||
)
|
||||
if batch.cross_attention_mask is not None:
|
||||
batch.cross_attention_mask = batch.cross_attention_mask[:, -1:]
|
||||
# Decrease right offset
|
||||
batch.padding_right_offset -= 1
|
||||
|
||||
|
@ -903,7 +918,8 @@ class IdeficsCausalLM(Model):
|
|||
# Update past key values
|
||||
batch.past_key_values = past
|
||||
batch.image_hidden_states = image_hidden_states
|
||||
|
||||
if self.model.config.model_type == "mllama":
|
||||
batch.pixel_values = None
|
||||
forward_ns = start_decode - start
|
||||
decode_ns = time.time_ns() - start_decode
|
||||
return generations, batch, (forward_ns, decode_ns)
|
||||
|
|
Loading…
Reference in New Issue