Working state ? (Broke idefics1 temporarily).
This commit is contained in:
parent
39d2073e93
commit
8abdd08ef4
|
@ -28,11 +28,17 @@ class ToolCall(BaseModel):
|
||||||
function: dict
|
function: dict
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk(BaseModel):
|
||||||
|
type: str
|
||||||
|
text: Optional[str] = None
|
||||||
|
image_url: Any = None
|
||||||
|
|
||||||
|
|
||||||
class Message(BaseModel):
|
class Message(BaseModel):
|
||||||
# Role of the message sender
|
# Role of the message sender
|
||||||
role: str
|
role: str
|
||||||
# Content of the message
|
# Content of the message
|
||||||
content: Optional[str] = None
|
content: Optional[Union[str, List[Chunk]]] = None
|
||||||
# Optional name of the message sender
|
# Optional name of the message sender
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
# Tool calls associated with the chat completion
|
# Tool calls associated with the chat completion
|
||||||
|
|
|
@ -0,0 +1,106 @@
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727097740,
|
||||||
|
"id": "",
|
||||||
|
"model": "s0409/model-3",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 24,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727097740,
|
||||||
|
"id": "",
|
||||||
|
"model": "s0409/model-3",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 24,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727097740,
|
||||||
|
"id": "",
|
||||||
|
"model": "s0409/model-3",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 24,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727097740,
|
||||||
|
"id": "",
|
||||||
|
"model": "s0409/model-3",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 24,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
|
@ -0,0 +1,26 @@
|
||||||
|
{
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"message": {
|
||||||
|
"content": "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak",
|
||||||
|
"name": null,
|
||||||
|
"role": "assistant",
|
||||||
|
"tool_calls": null
|
||||||
|
},
|
||||||
|
"usage": null
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1727090615,
|
||||||
|
"id": "",
|
||||||
|
"model": "s0409/model-3",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"system_fingerprint": "2.2.1-dev0-native",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 24,
|
||||||
|
"total_tokens": 44
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,108 @@
|
||||||
|
import pytest
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
def mllama_handle(launcher):
|
||||||
|
with launcher("s0409/model-3", num_shard=2) as handle:
|
||||||
|
yield handle
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
|
||||||
|
async def mllama(mllama_handle):
|
||||||
|
await mllama_handle.health(300)
|
||||||
|
return mllama_handle.client
|
||||||
|
|
||||||
|
|
||||||
|
# TODO fix the server parsser to count inline image tokens correctly
|
||||||
|
def get_chicken():
|
||||||
|
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
def get_cow_beach():
|
||||||
|
with open("integration-tests/images/cow_beach.png", "rb") as image_file:
|
||||||
|
encoded_string = base64.b64encode(image_file.read())
|
||||||
|
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mllama_simpl(mllama, response_snapshot):
|
||||||
|
# chicken = get_chicken()
|
||||||
|
response = await mllama.chat(
|
||||||
|
max_tokens=20,
|
||||||
|
temperature=0.0,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Can you tell me a very short story based on the image?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.usage == {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 24,
|
||||||
|
"total_tokens": 44,
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
response.choices[0].message.content
|
||||||
|
== "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
|
||||||
|
)
|
||||||
|
assert response == response_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.release
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mllama_load(mllama, generate_load, response_snapshot):
|
||||||
|
futures = [
|
||||||
|
mllama.chat(
|
||||||
|
max_tokens=20,
|
||||||
|
temperature=0.0,
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Can you tell me a very short story based on the image?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "https://raw.githubusercontent.com/huggingface/text-generation-inference/main/integration-tests/images/chicken_on_money.png"
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for i in range(4)
|
||||||
|
]
|
||||||
|
responses = await asyncio.gather(*futures)
|
||||||
|
|
||||||
|
generated_texts = [response.choices[0].message.content for response in responses]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
generated_texts[0]
|
||||||
|
== "In a small village, a rooster named Cluck Norris ruled the coop with an iron beak"
|
||||||
|
)
|
||||||
|
assert len(generated_texts) == 4
|
||||||
|
assert generated_texts, all(
|
||||||
|
[text == generated_texts[0] for text in generated_texts]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert responses == response_snapshot
|
|
@ -146,6 +146,7 @@ pub enum Config {
|
||||||
ClipVisionModel(ClipVisionModel),
|
ClipVisionModel(ClipVisionModel),
|
||||||
Mistral,
|
Mistral,
|
||||||
Idefics,
|
Idefics,
|
||||||
|
Mllama,
|
||||||
Idefics2(Idefics2),
|
Idefics2(Idefics2),
|
||||||
Ssm,
|
Ssm,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
|
|
|
@ -567,6 +567,7 @@ fn image_tokens(
|
||||||
use HubPreprocessorConfig::*;
|
use HubPreprocessorConfig::*;
|
||||||
match config {
|
match config {
|
||||||
Idefics => "<image>".to_string(),
|
Idefics => "<image>".to_string(),
|
||||||
|
Mllama => "<|image|>".to_string(),
|
||||||
Idefics2(config) => {
|
Idefics2(config) => {
|
||||||
const FAKE: &str = "<fake_token_around_image>";
|
const FAKE: &str = "<fake_token_around_image>";
|
||||||
const IMAGE: &str = "<image>";
|
const IMAGE: &str = "<image>";
|
||||||
|
@ -618,7 +619,7 @@ fn prepare_input(
|
||||||
use Config::*;
|
use Config::*;
|
||||||
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
static RE: Lazy<Regex> = Lazy::new(|| Regex::new(r"!\[\]\([^\)]*\)").unwrap());
|
||||||
let (tokenizer_query, input_chunks) = match config {
|
let (tokenizer_query, input_chunks) = match config {
|
||||||
Some(config @ (Idefics | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
Some(config @ (Idefics | Mllama | Idefics2(_) | Paligemma(_) | LlavaNext(_))) => {
|
||||||
let mut input_chunks = Vec::new();
|
let mut input_chunks = Vec::new();
|
||||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||||
let mut start = 0;
|
let mut start = 0;
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -98,6 +98,8 @@ class IDEFICSSharded(IdeficsCausalLM):
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
torch.distributed.barrier(group=self.process_group)
|
torch.distributed.barrier(group=self.process_group)
|
||||||
super(IdeficsCausalLM, self).__init__(
|
super(IdeficsCausalLM, self).__init__(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
|
@ -6,8 +6,6 @@ import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
|
||||||
AutoTokenizer,
|
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
ProcessorMixin,
|
||||||
)
|
)
|
||||||
|
@ -38,6 +36,9 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
attention_mask: torch.Tensor
|
attention_mask: torch.Tensor
|
||||||
position_ids: torch.Tensor
|
position_ids: torch.Tensor
|
||||||
pixel_values: Optional[torch.Tensor]
|
pixel_values: Optional[torch.Tensor]
|
||||||
|
aspect_ratio_ids: Optional[torch.Tensor]
|
||||||
|
aspect_ratio_mask: Optional[torch.Tensor]
|
||||||
|
cross_attention_mask: Optional[torch.Tensor]
|
||||||
image_hidden_states: Optional[torch.Tensor]
|
image_hidden_states: Optional[torch.Tensor]
|
||||||
image_attention_mask: Optional[torch.Tensor]
|
image_attention_mask: Optional[torch.Tensor]
|
||||||
past_key_values: Optional[List[Tuple]]
|
past_key_values: Optional[List[Tuple]]
|
||||||
|
@ -164,7 +165,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
image = Image.open(BytesIO(chunk.image.data))
|
image = Image.open(BytesIO(chunk.image.data))
|
||||||
curr_images.append(image)
|
curr_images.append(image)
|
||||||
# TODO unsure about BOS
|
# TODO unsure about BOS
|
||||||
curr_text += "<|image|><|begin_of_text|>"
|
curr_text += "<|image|>"
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
raise RuntimeError(f"Invalid chunk type {chunk_type}")
|
||||||
images.append(curr_images)
|
images.append(curr_images)
|
||||||
|
@ -173,6 +174,8 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
# The processor replaces the call to tokenizer, and
|
# The processor replaces the call to tokenizer, and
|
||||||
# a/ takes care of fetching images from the URL
|
# a/ takes care of fetching images from the URL
|
||||||
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
|
# b/ generate the correct input_ids, attention_mask, pixel_values, image_attention_mask to feed to the model
|
||||||
|
if all(len(im) == 0 for im in images):
|
||||||
|
images = None
|
||||||
tokenized_inputs = processor(
|
tokenized_inputs = processor(
|
||||||
images=images,
|
images=images,
|
||||||
text=texts,
|
text=texts,
|
||||||
|
@ -205,7 +208,10 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
# Do the same for image_attention_mask
|
# Do the same for image_attention_mask
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
image_attention_mask = None
|
image_attention_mask = None
|
||||||
else:
|
aspect_ratio_ids = None
|
||||||
|
aspect_ratio_mask = None
|
||||||
|
cross_attention_mask = None
|
||||||
|
elif "image_attention_mask" in tokenized_inputs:
|
||||||
image_attention_mask = input_ids.new_zeros(
|
image_attention_mask = input_ids.new_zeros(
|
||||||
(
|
(
|
||||||
pb.size,
|
pb.size,
|
||||||
|
@ -216,6 +222,19 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
|
||||||
"image_attention_mask"
|
"image_attention_mask"
|
||||||
]
|
]
|
||||||
|
aspect_ratio_ids = None
|
||||||
|
aspect_ratio_mask = None
|
||||||
|
cross_attention_mask = None
|
||||||
|
else:
|
||||||
|
image_attention_mask = None
|
||||||
|
aspect_ratio_ids = tokenized_inputs["aspect_ratio_ids"]
|
||||||
|
aspect_ratio_mask = tokenized_inputs["aspect_ratio_mask"]
|
||||||
|
cross_attention_mask = tokenized_inputs["cross_attention_mask"]
|
||||||
|
pixel_values = pixel_values.to(dtype=dtype)
|
||||||
|
# XXX: <|image|> token is actually out of bounds and bugs out the logit processors.
|
||||||
|
tokenized_inputs["input_ids"] = tokenized_inputs["input_ids"].clamp(
|
||||||
|
max=processor.tokenizer.vocab_size - 1
|
||||||
|
)
|
||||||
|
|
||||||
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
|
||||||
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
|
||||||
|
@ -245,6 +264,9 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
max_input_length=max_input_length.item(),
|
max_input_length=max_input_length.item(),
|
||||||
padding_right_offset=padding_right_offset,
|
padding_right_offset=padding_right_offset,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
aspect_ratio_ids=aspect_ratio_ids,
|
||||||
|
aspect_ratio_mask=aspect_ratio_mask,
|
||||||
|
cross_attention_mask=cross_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("filter")
|
@tracer.start_as_current_span("filter")
|
||||||
|
@ -308,7 +330,12 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
+ new_padding_right_offset,
|
+ new_padding_right_offset,
|
||||||
]
|
]
|
||||||
# Do the same for pixel_values and image_attention_mask
|
# Do the same for pixel_values and image_attention_mask
|
||||||
|
if self.pixel_values is not None:
|
||||||
pixel_values = self.pixel_values[keep_indices]
|
pixel_values = self.pixel_values[keep_indices]
|
||||||
|
else:
|
||||||
|
pixel_values = None
|
||||||
|
|
||||||
|
if self.image_attention_mask is not None:
|
||||||
self.image_attention_mask = self.image_attention_mask[
|
self.image_attention_mask = self.image_attention_mask[
|
||||||
keep_indices,
|
keep_indices,
|
||||||
-(self.padding_right_offset + max_input_length) : (
|
-(self.padding_right_offset + max_input_length) : (
|
||||||
|
@ -317,6 +344,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
+ new_padding_right_offset,
|
+ new_padding_right_offset,
|
||||||
:,
|
:,
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.image_hidden_states is None:
|
if self.image_hidden_states is None:
|
||||||
image_hidden_states = None
|
image_hidden_states = None
|
||||||
else:
|
else:
|
||||||
|
@ -359,6 +387,9 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
self.max_input_length = max_input_length
|
self.max_input_length = max_input_length
|
||||||
self.padding_right_offset = new_padding_right_offset
|
self.padding_right_offset = new_padding_right_offset
|
||||||
self.max_tokens = max_tokens
|
self.max_tokens = max_tokens
|
||||||
|
self.aspect_ratio_ids = None
|
||||||
|
self.aspect_ratio_mask = None
|
||||||
|
self.cross_attention_mask = None
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@ -376,6 +407,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
total_batch_size += len(batch)
|
total_batch_size += len(batch)
|
||||||
max_input_length = max(max_input_length, batch.max_input_length)
|
max_input_length = max(max_input_length, batch.max_input_length)
|
||||||
|
if batch.pixel_values is not None:
|
||||||
max_num_images = max(max_num_images, batch.pixel_values.size(1))
|
max_num_images = max(max_num_images, batch.pixel_values.size(1))
|
||||||
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
padding_right_offset = max(padding_right_offset, batch.padding_right_offset)
|
||||||
|
|
||||||
|
@ -439,6 +471,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
(total_batch_size, max_input_length + padding_right_offset),
|
(total_batch_size, max_input_length + padding_right_offset),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if batch.pixel_values is not None:
|
||||||
curr_batch_max_num_images = batch.pixel_values.size(1)
|
curr_batch_max_num_images = batch.pixel_values.size(1)
|
||||||
if pixel_values is None:
|
if pixel_values is None:
|
||||||
pixel_values = batch.pixel_values.new_zeros(
|
pixel_values = batch.pixel_values.new_zeros(
|
||||||
|
@ -447,8 +480,10 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
|
pixel_values[start_index:end_index, :curr_batch_max_num_images] = (
|
||||||
batch.pixel_values
|
batch.pixel_values
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
pixel_values = None
|
||||||
|
|
||||||
if image_attention_mask is None:
|
if image_attention_mask is None and batch.image_attention_mask is not None:
|
||||||
image_attention_mask = batch.image_attention_mask.new_zeros(
|
image_attention_mask = batch.image_attention_mask.new_zeros(
|
||||||
(
|
(
|
||||||
total_batch_size,
|
total_batch_size,
|
||||||
|
@ -472,6 +507,7 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
:,
|
:,
|
||||||
batch_left_offset : -batch.padding_right_offset,
|
batch_left_offset : -batch.padding_right_offset,
|
||||||
]
|
]
|
||||||
|
if batch.image_attention_mask is not None:
|
||||||
image_attention_mask[
|
image_attention_mask[
|
||||||
start_index:end_index,
|
start_index:end_index,
|
||||||
left_offset:-padding_right_offset,
|
left_offset:-padding_right_offset,
|
||||||
|
@ -531,7 +567,20 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
# Iterate over attention layers
|
# Iterate over attention layers
|
||||||
# Concatenate past key values layer by layer to allow incremental garbage collection
|
# Concatenate past key values layer by layer to allow incremental garbage collection
|
||||||
for j in range(len(first_past_kvs)):
|
for j in range(len(first_past_kvs)):
|
||||||
padded_past_keys = first_past_kvs[j][0].new_zeros(padded_past_keys_shape)
|
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][0].shape
|
||||||
|
if seqlen > max_input_length:
|
||||||
|
# XXX: This is probably a cross attention key value
|
||||||
|
# If not this is ok
|
||||||
|
_padded_past_keys_shape = (
|
||||||
|
total_batch_size,
|
||||||
|
_num_heads,
|
||||||
|
seqlen,
|
||||||
|
_head_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_padded_past_keys_shape = padded_past_keys_shape
|
||||||
|
|
||||||
|
padded_past_keys = first_past_kvs[j][0].new_zeros(_padded_past_keys_shape)
|
||||||
start_index = 0
|
start_index = 0
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
past_keys = batch.past_key_values[j][0]
|
past_keys = batch.past_key_values[j][0]
|
||||||
|
@ -542,6 +591,9 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
end_index = start_index + len(batch)
|
end_index = start_index + len(batch)
|
||||||
# We slice the keys to remove the padding from previous batches
|
# We slice the keys to remove the padding from previous batches
|
||||||
past_seq_len = batch.max_input_length - 1
|
past_seq_len = batch.max_input_length - 1
|
||||||
|
if past_keys.shape[2] > past_seq_len:
|
||||||
|
# XXX: This is a cross attention kv in mllama
|
||||||
|
past_seq_len = past_keys.shape[2]
|
||||||
if batch.keys_head_dim_last:
|
if batch.keys_head_dim_last:
|
||||||
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
|
padded_past_keys[start_index:end_index, :, -past_seq_len:, :] = (
|
||||||
past_keys[:, :, -past_seq_len:, :]
|
past_keys[:, :, -past_seq_len:, :]
|
||||||
|
@ -555,8 +607,20 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
|
|
||||||
start_index = end_index
|
start_index = end_index
|
||||||
|
|
||||||
|
_, _num_heads, seqlen, _head_dim = first_past_kvs[j][1].shape
|
||||||
|
if seqlen > max_input_length:
|
||||||
|
# XXX: This is probably a cross attention key value
|
||||||
|
# If not this is ok
|
||||||
|
_padded_past_values_shape = (
|
||||||
|
total_batch_size,
|
||||||
|
_num_heads,
|
||||||
|
seqlen,
|
||||||
|
_head_dim,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_padded_past_values_shape = padded_past_values_shape
|
||||||
padded_past_values = first_past_kvs[j][1].new_zeros(
|
padded_past_values = first_past_kvs[j][1].new_zeros(
|
||||||
padded_past_values_shape
|
_padded_past_values_shape
|
||||||
)
|
)
|
||||||
start_index = 0
|
start_index = 0
|
||||||
for batch in batches:
|
for batch in batches:
|
||||||
|
@ -568,6 +632,9 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
end_index = start_index + len(batch)
|
end_index = start_index + len(batch)
|
||||||
# We slice the past values to remove the padding from previous batches
|
# We slice the past values to remove the padding from previous batches
|
||||||
past_seq_len = batch.max_input_length - 1
|
past_seq_len = batch.max_input_length - 1
|
||||||
|
if past_values.shape[2] > past_seq_len:
|
||||||
|
# XXX: This is a cross attention kv in mllama
|
||||||
|
past_seq_len = past_values.shape[2]
|
||||||
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
|
padded_past_values[start_index:end_index, :, -past_seq_len:, :] = (
|
||||||
past_values[:, :, -past_seq_len:, :]
|
past_values[:, :, -past_seq_len:, :]
|
||||||
)
|
)
|
||||||
|
@ -599,6 +666,10 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
padding_right_offset=padding_right_offset,
|
padding_right_offset=padding_right_offset,
|
||||||
keys_head_dim_last=batches[0].keys_head_dim_last,
|
keys_head_dim_last=batches[0].keys_head_dim_last,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
# No need to keep this around. for Mllamma
|
||||||
|
aspect_ratio_ids=None,
|
||||||
|
aspect_ratio_mask=None,
|
||||||
|
cross_attention_mask=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
@ -606,77 +677,6 @@ class IdeficsCausalLMBatch(Batch):
|
||||||
|
|
||||||
|
|
||||||
class IdeficsCausalLM(Model):
|
class IdeficsCausalLM(Model):
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
revision: Optional[str] = None,
|
|
||||||
quantize: Optional[str] = None,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
trust_remote_code: bool = False,
|
|
||||||
):
|
|
||||||
self.quantize = quantize
|
|
||||||
from text_generation_server.models.custom_modeling.idefics_modeling import (
|
|
||||||
IdeficsForVisionText2Text,
|
|
||||||
)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
device = torch.device("cuda")
|
|
||||||
dtype = torch.bfloat16 if dtype is None else dtype
|
|
||||||
else:
|
|
||||||
if quantize:
|
|
||||||
raise ValueError("quantization is not available on CPU")
|
|
||||||
|
|
||||||
device = torch.device("cpu")
|
|
||||||
dtype = torch.float32 if dtype is None else dtype
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
self.processor = AutoProcessor.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
padding_side="left",
|
|
||||||
truncation_side="left",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
model = IdeficsForVisionText2Text.from_pretrained(
|
|
||||||
model_id,
|
|
||||||
revision=revision,
|
|
||||||
torch_dtype=dtype,
|
|
||||||
device_map=(
|
|
||||||
"auto"
|
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() > 1
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
load_in_8bit=quantize == "bitsandbytes",
|
|
||||||
trust_remote_code=trust_remote_code,
|
|
||||||
)
|
|
||||||
if torch.cuda.is_available() and torch.cuda.device_count() == 1:
|
|
||||||
model = model.cuda()
|
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None:
|
|
||||||
if model.config.pad_token_id is not None:
|
|
||||||
tokenizer.pad_token_id = model.config.pad_token_id
|
|
||||||
elif model.config.eos_token_id is not None:
|
|
||||||
tokenizer.pad_token_id = model.config.eos_token_id
|
|
||||||
elif tokenizer.eos_token_id is not None:
|
|
||||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
|
||||||
else:
|
|
||||||
tokenizer.add_special_tokens({"pad_token": "<unk>"})
|
|
||||||
|
|
||||||
super(IdeficsCausalLM, self).__init__(
|
|
||||||
model_id=model_id,
|
|
||||||
model=model,
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
requires_padding=True,
|
|
||||||
dtype=dtype,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
|
def batch_type(self) -> Type[IdeficsCausalLMBatch]:
|
||||||
return IdeficsCausalLMBatch
|
return IdeficsCausalLMBatch
|
||||||
|
@ -690,6 +690,9 @@ class IdeficsCausalLM(Model):
|
||||||
image_hidden_states,
|
image_hidden_states,
|
||||||
image_attention_mask,
|
image_attention_mask,
|
||||||
past_key_values: Optional = None,
|
past_key_values: Optional = None,
|
||||||
|
aspect_ratio_ids=None,
|
||||||
|
aspect_ratio_mask=None,
|
||||||
|
cross_attention_mask=None,
|
||||||
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||||
# Model Forward
|
# Model Forward
|
||||||
kwargs = {
|
kwargs = {
|
||||||
|
@ -699,18 +702,23 @@ class IdeficsCausalLM(Model):
|
||||||
"image_hidden_states": image_hidden_states,
|
"image_hidden_states": image_hidden_states,
|
||||||
"image_attention_mask": image_attention_mask,
|
"image_attention_mask": image_attention_mask,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"use_cache": True,
|
|
||||||
"return_dict": True,
|
|
||||||
}
|
}
|
||||||
if self.has_position_ids:
|
if self.has_position_ids:
|
||||||
kwargs["position_ids"] = position_ids
|
kwargs["position_ids"] = position_ids
|
||||||
|
if aspect_ratio_ids is not None:
|
||||||
|
kwargs["aspect_ratio_ids"] = aspect_ratio_ids
|
||||||
|
if aspect_ratio_mask is not None:
|
||||||
|
kwargs["aspect_ratio_mask"] = aspect_ratio_mask
|
||||||
|
if cross_attention_mask is not None:
|
||||||
|
kwargs["cross_attention_mask"] = cross_attention_mask
|
||||||
|
|
||||||
outputs, speculative_logits = self.model.forward(**kwargs)
|
outputs, speculative_logits = self.model.forward(**kwargs)
|
||||||
|
assert outputs.past_key_values is not None
|
||||||
return (
|
return (
|
||||||
outputs.logits,
|
outputs.logits,
|
||||||
speculative_logits,
|
speculative_logits,
|
||||||
outputs.past_key_values,
|
outputs.past_key_values,
|
||||||
outputs.image_hidden_states,
|
getattr(outputs, "image_hidden_states", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
@tracer.start_as_current_span("generate_token")
|
@tracer.start_as_current_span("generate_token")
|
||||||
|
@ -745,8 +753,12 @@ class IdeficsCausalLM(Model):
|
||||||
image_hidden_states=batch.image_hidden_states,
|
image_hidden_states=batch.image_hidden_states,
|
||||||
image_attention_mask=image_attention_mask,
|
image_attention_mask=image_attention_mask,
|
||||||
past_key_values=batch.past_key_values,
|
past_key_values=batch.past_key_values,
|
||||||
|
aspect_ratio_ids=batch.aspect_ratio_ids,
|
||||||
|
aspect_ratio_mask=batch.aspect_ratio_mask,
|
||||||
|
cross_attention_mask=batch.cross_attention_mask,
|
||||||
)
|
)
|
||||||
# Hardcoded remove image tokens
|
# Hardcoded remove image tokens
|
||||||
|
if self.config.model_type == "idefics":
|
||||||
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
|
logits[:, 32000:32001] = torch.finfo(logits.dtype).min
|
||||||
|
|
||||||
start_decode = time.time_ns()
|
start_decode = time.time_ns()
|
||||||
|
@ -890,10 +902,13 @@ class IdeficsCausalLM(Model):
|
||||||
batch.input_ids = batch.input_ids[:, :1]
|
batch.input_ids = batch.input_ids[:, :1]
|
||||||
|
|
||||||
# Update attention_mask as we added a new token to input_ids
|
# Update attention_mask as we added a new token to input_ids
|
||||||
batch.attention_mask[:, -batch.padding_right_offset] = 1
|
# batch.attention_mask[:, -batch.padding_right_offset] = 1
|
||||||
|
if batch.image_attention_mask is not None:
|
||||||
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
|
batch.image_attention_mask[:, -batch.padding_right_offset, :] = (
|
||||||
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
|
batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
|
||||||
)
|
)
|
||||||
|
if batch.cross_attention_mask is not None:
|
||||||
|
batch.cross_attention_mask = batch.cross_attention_mask[:, -1:]
|
||||||
# Decrease right offset
|
# Decrease right offset
|
||||||
batch.padding_right_offset -= 1
|
batch.padding_right_offset -= 1
|
||||||
|
|
||||||
|
@ -903,7 +918,8 @@ class IdeficsCausalLM(Model):
|
||||||
# Update past key values
|
# Update past key values
|
||||||
batch.past_key_values = past
|
batch.past_key_values = past
|
||||||
batch.image_hidden_states = image_hidden_states
|
batch.image_hidden_states = image_hidden_states
|
||||||
|
if self.model.config.model_type == "mllama":
|
||||||
|
batch.pixel_values = None
|
||||||
forward_ns = start_decode - start
|
forward_ns = start_decode - start
|
||||||
decode_ns = time.time_ns() - start_decode
|
decode_ns = time.time_ns() - start_decode
|
||||||
return generations, batch, (forward_ns, decode_ns)
|
return generations, batch, (forward_ns, decode_ns)
|
||||||
|
|
Loading…
Reference in New Issue