Idefics2. (#1756)
# What does this PR do? <!-- Congratulations! You've made it this far! You're not quite done yet though. Once merged, your PR is going to appear in the release notes with the title you set, so make sure it's a great title that fully reflects the extent of your awesome contribution. Then, please replace this with a description of the change and which issue is fixed (if applicable). Please also include relevant motivation and context. List any dependencies (if any) that are required for this change. Once you're done, someone will review your PR shortly (see the section "Who can review?" below to tag some potential reviewers). They may suggest changes to make the code even better. If no one reviewed your PR after a week has passed, don't hesitate to post a new comment @-mentioning the same persons---sometimes notifications get lost. --> <!-- Remove if not applicable --> Fixes # (issue) ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section? - [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case. - [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation). - [ ] Did you write any new necessary tests? ## Who can review? Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR. <!-- Your PR will be replied to more quickly if you can figure out the right person to tag with @ @OlivierDehaene OR @Narsil -->
This commit is contained in:
parent
986b4044d1
commit
bfddfa5955
|
@ -1,5 +1,6 @@
|
|||
## Speculation
|
||||
|
||||
|
||||
Speculative decoding, assisted generation, Medusa, and others are a few different names for the same idea.
|
||||
The idea is to generate tokens *before* the large model actually runs, and only *check* if those tokens where valid.
|
||||
|
||||
|
|
|
@ -293,6 +293,7 @@ def launcher(event_loop):
|
|||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
@ -334,6 +335,9 @@ def launcher(event_loop):
|
|||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_batch_prefill_tokens:
|
||||
args.append("--max-batch-prefill-tokens")
|
||||
args.append(str(max_batch_prefill_tokens))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
@ -371,6 +375,7 @@ def launcher(event_loop):
|
|||
dtype: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
max_batch_prefill_tokens: Optional[int] = None,
|
||||
max_total_tokens: Optional[int] = None,
|
||||
):
|
||||
port = random.randint(8000, 10_000)
|
||||
|
@ -395,6 +400,9 @@ def launcher(event_loop):
|
|||
if max_input_length:
|
||||
args.append("--max-input-length")
|
||||
args.append(str(max_input_length))
|
||||
if max_batch_prefill_tokens:
|
||||
args.append("--max-batch-prefill-tokens")
|
||||
args.append(str(max_batch_prefill_tokens))
|
||||
if max_total_tokens:
|
||||
args.append("--max-total-tokens")
|
||||
args.append(str(max_total_tokens))
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [
|
||||
{
|
||||
"id": 1,
|
||||
"logprob": null,
|
||||
"text": "<s>"
|
||||
},
|
||||
{
|
||||
"id": 3735,
|
||||
"logprob": -8.5625,
|
||||
"text": "Test"
|
||||
},
|
||||
{
|
||||
"id": 2159,
|
||||
"logprob": -10.78125,
|
||||
"text": "request"
|
||||
}
|
||||
],
|
||||
"seed": 0,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 288,
|
||||
"logprob": -0.2854004,
|
||||
"special": false,
|
||||
"text": "ing"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.37573242,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 633,
|
||||
"logprob": -0.09301758,
|
||||
"special": false,
|
||||
"text": " new"
|
||||
},
|
||||
{
|
||||
"id": 4480,
|
||||
"logprob": -0.3322754,
|
||||
"special": false,
|
||||
"text": " feature"
|
||||
},
|
||||
{
|
||||
"id": 297,
|
||||
"logprob": -0.8510742,
|
||||
"special": false,
|
||||
"text": " in"
|
||||
},
|
||||
{
|
||||
"id": 272,
|
||||
"logprob": -0.13464355,
|
||||
"special": false,
|
||||
"text": " the"
|
||||
},
|
||||
{
|
||||
"id": 2039,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": " game"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.89990234,
|
||||
"special": false,
|
||||
"text": "."
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
},
|
||||
{
|
||||
"id": 13,
|
||||
"logprob": 0.0,
|
||||
"special": false,
|
||||
"text": "\n"
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": "Test requesting a new feature in the game.\n\n"
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,73 @@
|
|||
{
|
||||
"details": {
|
||||
"best_of_sequences": null,
|
||||
"finish_reason": "length",
|
||||
"generated_tokens": 10,
|
||||
"prefill": [],
|
||||
"seed": null,
|
||||
"tokens": [
|
||||
{
|
||||
"id": 330,
|
||||
"logprob": -0.13000488,
|
||||
"special": false,
|
||||
"text": " A"
|
||||
},
|
||||
{
|
||||
"id": 13088,
|
||||
"logprob": -0.6713867,
|
||||
"special": false,
|
||||
"text": " chicken"
|
||||
},
|
||||
{
|
||||
"id": 349,
|
||||
"logprob": -0.2980957,
|
||||
"special": false,
|
||||
"text": " is"
|
||||
},
|
||||
{
|
||||
"id": 6398,
|
||||
"logprob": -0.060638428,
|
||||
"special": false,
|
||||
"text": " sitting"
|
||||
},
|
||||
{
|
||||
"id": 356,
|
||||
"logprob": -0.27319336,
|
||||
"special": false,
|
||||
"text": " on"
|
||||
},
|
||||
{
|
||||
"id": 264,
|
||||
"logprob": -0.140625,
|
||||
"special": false,
|
||||
"text": " a"
|
||||
},
|
||||
{
|
||||
"id": 17972,
|
||||
"logprob": -0.040405273,
|
||||
"special": false,
|
||||
"text": " pile"
|
||||
},
|
||||
{
|
||||
"id": 302,
|
||||
"logprob": -0.0002708435,
|
||||
"special": false,
|
||||
"text": " of"
|
||||
},
|
||||
{
|
||||
"id": 2445,
|
||||
"logprob": -0.095336914,
|
||||
"special": false,
|
||||
"text": " money"
|
||||
},
|
||||
{
|
||||
"id": 28723,
|
||||
"logprob": -0.0068359375,
|
||||
"special": false,
|
||||
"text": "."
|
||||
}
|
||||
],
|
||||
"top_tokens": null
|
||||
},
|
||||
"generated_text": " A chicken is sitting on a pile of money."
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,81 @@
|
|||
import pytest
|
||||
import base64
|
||||
|
||||
|
||||
# TODO fix the server parsser to count inline image tokens correctly
|
||||
def get_chicken():
|
||||
with open("integration-tests/images/chicken_on_money.png", "rb") as image_file:
|
||||
encoded_string = base64.b64encode(image_file.read())
|
||||
return f"data:image/png;base64,{encoded_string.decode('utf-8')}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def flash_idefics2_next_handle(launcher):
|
||||
with launcher(
|
||||
"HuggingFaceM4/idefics2-8b",
|
||||
) as handle:
|
||||
yield handle
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
async def flash_idefics2_next(flash_idefics2_next_handle):
|
||||
await flash_idefics2_next_handle.health(300)
|
||||
return flash_idefics2_next_handle.client
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_idefics2_next_simple(flash_idefics2_next, response_snapshot):
|
||||
chicken = get_chicken()
|
||||
response = await flash_idefics2_next.generate(
|
||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||
max_new_tokens=10,
|
||||
)
|
||||
assert (
|
||||
response.generated_text == " A chicken is sitting on a pile of money."
|
||||
), f"{repr(response.generated_text)}"
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_idefics2_next_all_params(flash_idefics2_next, response_snapshot):
|
||||
response = await flash_idefics2_next.generate(
|
||||
"Test request",
|
||||
max_new_tokens=10,
|
||||
repetition_penalty=1.2,
|
||||
return_full_text=True,
|
||||
stop_sequences=["test"],
|
||||
temperature=0.5,
|
||||
top_p=0.9,
|
||||
top_k=10,
|
||||
truncate=5,
|
||||
typical_p=0.9,
|
||||
watermark=True,
|
||||
decoder_input_details=True,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert response.details.generated_tokens == 10
|
||||
assert response == response_snapshot
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.private
|
||||
async def test_flash_idefics2_next_load(
|
||||
flash_idefics2_next, generate_load, response_snapshot
|
||||
):
|
||||
chicken = get_chicken()
|
||||
responses = await generate_load(
|
||||
flash_idefics2_next,
|
||||
f"User:![]({chicken})Write me a short story<end_of_utterance> \nAssistant:",
|
||||
max_new_tokens=10,
|
||||
n=4,
|
||||
)
|
||||
generated_texts = [r.generated_text for r in responses]
|
||||
assert generated_texts[0] == " A chicken is sitting on a pile of money."
|
||||
assert len(generated_texts) == 4
|
||||
assert all([r.generated_text == generated_texts[0] for r in responses])
|
||||
|
||||
assert responses == response_snapshot
|
|
@ -114,8 +114,12 @@ impl Client {
|
|||
let truncate = min(max_input_length, max_prefill_tokens - n_tokens);
|
||||
|
||||
let mut inputs = String::new();
|
||||
inputs.push_str("![](");
|
||||
inputs.push_str(&"_test ".to_string().repeat(max_input_length as usize));
|
||||
if n_tokens == 0 {
|
||||
// 1 request is enough to test vision heads.
|
||||
// Sending images on other queries messes up easily with truncation.
|
||||
inputs.push_str("![]()");
|
||||
}
|
||||
|
||||
requests.push(Request {
|
||||
id: 0,
|
||||
|
|
|
@ -57,6 +57,31 @@ fn select_best_resolution(
|
|||
best_fit.unwrap_or((original_height, original_width))
|
||||
}
|
||||
|
||||
fn get_unpadded_features(
|
||||
height: usize,
|
||||
width: usize,
|
||||
npatches: usize,
|
||||
num_patch_height: usize,
|
||||
num_patch_width: usize,
|
||||
) -> (usize, usize) {
|
||||
let current_height = npatches * num_patch_height;
|
||||
let current_width = npatches * num_patch_width;
|
||||
|
||||
let aspect_ratio: f64 = width as f64 / height as f64;
|
||||
let current_aspect_ratio: f64 = current_width as f64 / current_height as f64;
|
||||
let (current_height, current_width) = if aspect_ratio > current_aspect_ratio {
|
||||
let new_height = (height * current_width) / width;
|
||||
(new_height, current_width)
|
||||
} else {
|
||||
let new_width = (width * current_height) / height;
|
||||
(current_height, new_width)
|
||||
};
|
||||
|
||||
let unpadded_features = current_height * current_width;
|
||||
let newline_features = current_height;
|
||||
(unpadded_features, newline_features)
|
||||
}
|
||||
|
||||
impl LlavaNext {
|
||||
pub fn get_number_of_features(&self, height: usize, width: usize) -> usize {
|
||||
let image_size = self.vision_config.image_size;
|
||||
|
@ -65,11 +90,9 @@ impl LlavaNext {
|
|||
let npatches = image_size / patch_size;
|
||||
let (num_patch_height, num_patch_width) =
|
||||
get_anyres_image_grid_shape(height, width, &self.image_grid_pinpoints, image_size);
|
||||
// Ceil
|
||||
let height_of_patch = (height * npatches + width - 1) / width;
|
||||
let unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width;
|
||||
// They are only added after width
|
||||
let newline_features = height_of_patch * num_patch_width;
|
||||
|
||||
let (unpadded_features, newline_features) =
|
||||
get_unpadded_features(height, width, npatches, num_patch_height, num_patch_width);
|
||||
// The base patch covers the entire image
|
||||
let base_features = npatches.pow(2);
|
||||
unpadded_features + newline_features + base_features
|
||||
|
@ -84,6 +107,17 @@ pub struct ClipVisionModel {
|
|||
patch_size: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub struct Idefics2 {}
|
||||
|
||||
impl Idefics2 {
|
||||
pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
|
||||
320
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "model_type")]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
|
@ -92,6 +126,7 @@ pub enum Config {
|
|||
ClipVisionModel(ClipVisionModel),
|
||||
Mistral,
|
||||
Idefics,
|
||||
Idefics2(Idefics2),
|
||||
Ssm,
|
||||
GptBigcode,
|
||||
Santacoder,
|
||||
|
@ -146,13 +181,17 @@ mod test {
|
|||
],
|
||||
};
|
||||
|
||||
let slots = config.get_number_of_features(20, 20);
|
||||
assert_eq!(slots, 1176);
|
||||
let slots = config.get_number_of_features(640, 640);
|
||||
assert_eq!(slots, 2928);
|
||||
let slots = config.get_number_of_features(480, 640);
|
||||
assert_eq!(slots, 2340);
|
||||
let slots = config.get_number_of_features(899, 1024);
|
||||
assert_eq!(slots, 2732);
|
||||
assert_eq!(slots, 2634);
|
||||
let slots = config.get_number_of_features(1024, 899);
|
||||
assert_eq!(slots, 3320);
|
||||
assert_eq!(slots, 2640);
|
||||
let slots = config.get_number_of_features(1067, 1600);
|
||||
assert_eq!(slots, 2144);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -540,7 +540,57 @@ fn prepare_input(
|
|||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
}
|
||||
Some(Config::Idefics) => RE.replace_all(&inputs, "<image>").into(),
|
||||
Some(Config::Idefics2(config)) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = config.get_number_of_features(height, width);
|
||||
tokenizer_query.push_str("<fake_token_around_image>");
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
tokenizer_query.push_str("<fake_token_around_image>");
|
||||
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
}
|
||||
Some(Config::Idefics) => {
|
||||
let mut modified_inputs = String::with_capacity(inputs.len());
|
||||
let mut tokenizer_query = String::with_capacity(inputs.len());
|
||||
let mut start = 0;
|
||||
for chunk in RE.find_iter(&inputs) {
|
||||
let chunk_start = chunk.start();
|
||||
let chunk_end = chunk.end();
|
||||
if chunk_start != start {
|
||||
modified_inputs.push_str(&inputs[start..chunk_start]);
|
||||
tokenizer_query.push_str(&inputs[start..chunk_start]);
|
||||
}
|
||||
let (image_uri, _height, _width) = fetch_image(&inputs[chunk_start..chunk_end])?;
|
||||
let slots = 1;
|
||||
tokenizer_query.push_str(&"<image>".repeat(slots));
|
||||
modified_inputs.push_str(&image_uri);
|
||||
start = chunk_end;
|
||||
}
|
||||
if start != inputs.len() - 1 {
|
||||
modified_inputs.push_str(&inputs[start..]);
|
||||
tokenizer_query.push_str(&inputs[start..]);
|
||||
}
|
||||
inputs = modified_inputs;
|
||||
tokenizer_query
|
||||
}
|
||||
_ => inputs.clone(),
|
||||
};
|
||||
|
||||
|
|
|
@ -68,6 +68,7 @@ try:
|
|||
)
|
||||
from text_generation_server.models.idefics import IDEFICSSharded
|
||||
from text_generation_server.models.llava_next import LlavaNext
|
||||
from text_generation_server.models.idefics2 import Idefics2
|
||||
from text_generation_server.models.flash_mistral import FlashMistral
|
||||
from text_generation_server.models.flash_mixtral import FlashMixtral
|
||||
from text_generation_server.models.flash_phi import FlashPhi
|
||||
|
@ -579,6 +580,18 @@ def get_model(
|
|||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
if model_type == "idefics2":
|
||||
if FLASH_ATTENTION:
|
||||
return Idefics2(
|
||||
model_id,
|
||||
revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
|
||||
|
||||
if model_type == "llava_next":
|
||||
if FLASH_ATTENTION:
|
||||
|
|
|
@ -409,23 +409,29 @@ class MistralModel(torch.nn.Module):
|
|||
|
||||
|
||||
class FlashMistralForCausalLM(torch.nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, name=None):
|
||||
if name is None:
|
||||
name = "model"
|
||||
super().__init__()
|
||||
|
||||
self.embed_tokens = TensorParallelEmbedding(
|
||||
prefix=(
|
||||
"model.embed_tokens" if not prefix else f"{prefix}.model.embed_tokens"
|
||||
f"{name}.embed_tokens"
|
||||
if not prefix
|
||||
else f"{prefix}.{name}.embed_tokens"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.model = MistralModel(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
prefix=name if not prefix else f"{prefix}.{name}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.lm_head = SpeculativeHead.load(
|
||||
config,
|
||||
prefix="lm_head" if not prefix else f"{prefix}.lm_head",
|
||||
# TODO dirty hack for idefics2.
|
||||
prefix=(
|
||||
"lm_head" if not prefix or name != "model" else f"{prefix}.lm_head"
|
||||
),
|
||||
weights=weights,
|
||||
)
|
||||
self.max_past = config.sliding_window
|
||||
|
|
|
@ -0,0 +1,829 @@
|
|||
# coding=utf-8
|
||||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" PyTorch Idefics2 model."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
class Idefics2VisionEmbeddings(nn.Module):
|
||||
"""
|
||||
This is a modified version of `siglip.modelign_siglip.SiglipVisionEmbeddings` to enable images of variable
|
||||
resolution.
|
||||
|
||||
The modifications are adapted from [Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution](https://arxiv.org/abs/2307.06304)
|
||||
which allows treating images in their native aspect ratio and without the need to resize them to the same
|
||||
fixed size. In particular, we start from the original pre-trained SigLIP model
|
||||
(which uses images of fixed-size square images) and adapt it by training on images of variable resolutions.
|
||||
"""
|
||||
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=config.num_channels,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
padding="valid",
|
||||
)
|
||||
self.patch_embedding.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.weight"), requires_grad=False
|
||||
)
|
||||
self.patch_embedding.bias = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.patch_embedding.bias"), requires_grad=False
|
||||
)
|
||||
|
||||
self.num_patches_per_side = self.image_size // self.patch_size
|
||||
self.num_patches = self.num_patches_per_side**2
|
||||
self.num_positions = self.num_patches
|
||||
self.position_embedding = TensorParallelEmbedding(
|
||||
prefix=f"{prefix}.position_embedding", weights=weights
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
|
||||
) -> torch.Tensor:
|
||||
batch_size, _, max_im_h, max_im_w = pixel_values.shape
|
||||
|
||||
patch_embeds = self.patch_embedding(pixel_values)
|
||||
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
max_nb_patches_h, max_nb_patches_w = (
|
||||
max_im_h // self.patch_size,
|
||||
max_im_w // self.patch_size,
|
||||
)
|
||||
boundaries = torch.arange(
|
||||
1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side
|
||||
)
|
||||
position_ids = torch.full(
|
||||
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
|
||||
)
|
||||
|
||||
for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
|
||||
nb_patches_h = p_attn_mask[:, 0].sum()
|
||||
nb_patches_w = p_attn_mask[0].sum()
|
||||
|
||||
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
|
||||
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
|
||||
|
||||
bucket_coords_h = torch.bucketize(
|
||||
fractional_coords_h, boundaries, right=True
|
||||
)
|
||||
bucket_coords_w = torch.bucketize(
|
||||
fractional_coords_w, boundaries, right=True
|
||||
)
|
||||
|
||||
pos_ids = (
|
||||
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
|
||||
).flatten()
|
||||
position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids
|
||||
|
||||
position_ids = position_ids.to(self.position_embedding.weight.device)
|
||||
embeddings = embeddings + self.position_embedding(position_ids)
|
||||
return embeddings
|
||||
|
||||
|
||||
class Idefics2VisionAttention(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_size = self.embed_dim // self.num_heads
|
||||
if self.head_size * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
|
||||
f" {self.num_heads})."
|
||||
)
|
||||
self.scale = self.head_size**-0.5
|
||||
self.dropout = config.attention_dropout
|
||||
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.embed_dim = self.embed_dim // weights.process_group.size()
|
||||
|
||||
self.qkv = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=True,
|
||||
)
|
||||
self.out_proj = TensorParallelRowLinear.load(
|
||||
config=config, prefix=f"{prefix}.out_proj", weights=weights, bias=True
|
||||
)
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, q_len, _ = hidden_states.size()
|
||||
|
||||
qkv = self.qkv(hidden_states)
|
||||
query_states, key_states, value_states = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
self.head_size * self.num_heads,
|
||||
self.head_size * self.num_heads,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
query_states = query_states.view(
|
||||
batch_size, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
batch_size, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
batch_size, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
|
||||
k_v_seq_len = key_states.shape[-2]
|
||||
attn_weights = (
|
||||
torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
|
||||
)
|
||||
|
||||
if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_weights = nn.functional.dropout(
|
||||
attn_weights, p=self.dropout, training=self.training
|
||||
)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
|
||||
|
||||
attn_output = self.out_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Idefics2VisionMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = ACT2FN[config.hidden_act]
|
||||
self.fc1 = TensorParallelColumnLinear.load(
|
||||
prefix=f"{prefix}.fc1", config=config, weights=weights, bias=True
|
||||
)
|
||||
self.fc2 = TensorParallelRowLinear.load(
|
||||
prefix=f"{prefix}.fc2", config=config, weights=weights, bias=True
|
||||
)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Idefics2EncoderLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = Idefics2VisionAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.layer_norm1 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm1", eps=config.layer_norm_eps, weights=weights
|
||||
)
|
||||
self.layer_norm2 = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.layer_norm2", eps=config.layer_norm_eps, weights=weights
|
||||
)
|
||||
self.mlp = Idefics2VisionMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights
|
||||
)
|
||||
|
||||
# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoderLayer.forward
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Idefics2Encoder(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Idefics2EncoderLayer(
|
||||
prefix=f"{prefix}.layers.{i}", config=config, weights=weights
|
||||
)
|
||||
for i in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Idefics2VisionTransformer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embeddings = Idefics2VisionEmbeddings(
|
||||
prefix=f"{prefix}.embeddings", config=config, weights=weights
|
||||
)
|
||||
self.encoder = Idefics2Encoder(
|
||||
prefix=f"{prefix}.encoder", config=config, weights=weights
|
||||
)
|
||||
self.post_layernorm = nn.LayerNorm.load(
|
||||
prefix=f"{prefix}.post_layernorm",
|
||||
weights=weights,
|
||||
eps=config.layer_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values,
|
||||
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
):
|
||||
batch_size = pixel_values.size(0)
|
||||
if patch_attention_mask is None:
|
||||
patch_size = self.config.patch_size
|
||||
patch_attention_mask = torch.ones(
|
||||
(
|
||||
batch_size,
|
||||
pixel_values.size(2) // patch_size,
|
||||
pixel_values.size(3) // patch_size,
|
||||
)
|
||||
)
|
||||
patch_attention_mask = patch_attention_mask.to(
|
||||
dtype=torch.bool, device=pixel_values.device
|
||||
)
|
||||
|
||||
hidden_states = self.embeddings(
|
||||
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
|
||||
)
|
||||
|
||||
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
||||
# The call to `_upad_input` in `_flash_attention_forward` is expensive
|
||||
# So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence),
|
||||
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
|
||||
if not torch.any(~patch_attention_mask):
|
||||
patch_attention_mask = None
|
||||
else:
|
||||
patch_attention_mask = _prepare_4d_attention_mask(
|
||||
patch_attention_mask, hidden_states.dtype
|
||||
)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
last_hidden_state = encoder_outputs
|
||||
last_hidden_state = self.post_layernorm(last_hidden_state)
|
||||
|
||||
return last_hidden_state
|
||||
|
||||
|
||||
class Idefics2MLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
act = config.text_config.hidden_act
|
||||
self.act = (
|
||||
ACT2FN[act]
|
||||
if "gelu" not in act
|
||||
else lambda x: torch.nn.functional.gelu(
|
||||
x,
|
||||
approximate=(
|
||||
"tanh" if act in ["gelu_fast", "gelu_pytorch_tanh"] else "none"
|
||||
),
|
||||
)
|
||||
)
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
start_shape = hidden_states.shape[:-1]
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
intermediate_size = gate_up_states.shape[-1] // 2
|
||||
gate_up_states = gate_up_states.view(-1, 2, intermediate_size)
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]
|
||||
).view(*start_shape, -1)
|
||||
|
||||
|
||||
class Idefics2RMSNorm(nn.Module):
|
||||
def __init__(self, prefix, weights, eps):
|
||||
"""
|
||||
Idefics2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(
|
||||
weights.get_tensor(f"{prefix}.weight"), requires_grad=False
|
||||
)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class Idefics2PerceiverAttention(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
||||
self.layer_idx = None
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
self.num_heads = config.perceiver_config.resampler_n_heads
|
||||
self.head_size = config.perceiver_config.resampler_head_dim
|
||||
self.num_key_value_heads = config.perceiver_config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.attention_dropout = config.perceiver_config.attention_dropout
|
||||
self.num_heads = self.num_heads // weights.process_group.size()
|
||||
self.num_key_value_heads = (
|
||||
self.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.q_proj = TensorParallelColumnLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.q_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.kv = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
config=config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
|
||||
)
|
||||
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = latents.size()
|
||||
kv_seq_len = q_len + context.size()[1]
|
||||
|
||||
hidden_states = torch.concat([context, latents], dim=-2)
|
||||
query_states = self.q_proj(latents)
|
||||
kv = self.kv(hidden_states)
|
||||
key_states, value_states = kv.split(
|
||||
[
|
||||
self.head_size * self.num_key_value_heads,
|
||||
self.head_size * self.num_key_value_heads,
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, kv_seq_len, self.num_key_value_heads, self.head_size
|
||||
).transpose(1, 2)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_size)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_size):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_size)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_size)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class Idefics2PerceiverLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||
self.depth = config.perceiver_config.resampler_depth
|
||||
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||
|
||||
self.input_latents_norm = Idefics2RMSNorm(
|
||||
prefix=f"{prefix}.input_latents_norm",
|
||||
weights=weights,
|
||||
eps=self.rms_norm_eps,
|
||||
)
|
||||
self.input_context_norm = Idefics2RMSNorm(
|
||||
prefix=f"{prefix}.input_context_norm",
|
||||
weights=weights,
|
||||
eps=self.rms_norm_eps,
|
||||
)
|
||||
self.self_attn = Idefics2PerceiverAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
)
|
||||
self.post_attention_layernorm = Idefics2RMSNorm(
|
||||
prefix=f"{prefix}.post_attention_layernorm",
|
||||
weights=weights,
|
||||
eps=self.rms_norm_eps,
|
||||
)
|
||||
self.mlp = Idefics2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
"""
|
||||
residual = latents
|
||||
|
||||
latents = self.input_latents_norm(latents)
|
||||
context = self.input_context_norm(context)
|
||||
|
||||
latents = self.self_attn(
|
||||
latents=latents,
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
latents = residual + latents
|
||||
residual = latents
|
||||
|
||||
latents = self.post_attention_layernorm(latents)
|
||||
latents = self.mlp(latents)
|
||||
latents = residual + latents
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class Idefics2PerceiverResampler(nn.Module):
|
||||
def __init__(self, prefix, config, weights) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.text_config.hidden_size
|
||||
self.hidden_act = config.perceiver_config.hidden_act
|
||||
self.n_latents = config.perceiver_config.resampler_n_latents
|
||||
self.depth = config.perceiver_config.resampler_depth
|
||||
self.rms_norm_eps = config.text_config.rms_norm_eps
|
||||
|
||||
# Create Latents for Perceiver
|
||||
self.latents = weights.get_tensor(f"{prefix}.latents")
|
||||
|
||||
# Create Transformer Blocks
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
Idefics2PerceiverLayer(
|
||||
prefix=f"{prefix}.layers.{idx}", config=config, weights=weights
|
||||
)
|
||||
for idx in range(self.depth)
|
||||
]
|
||||
)
|
||||
self.norm = Idefics2RMSNorm(
|
||||
prefix=f"{prefix}.norm",
|
||||
weights=weights,
|
||||
eps=config.text_config.rms_norm_eps,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
attention_mask,
|
||||
) -> torch.Tensor:
|
||||
# seq embed -> bsz seq embed
|
||||
latents = self.latents.unsqueeze(0).expand(
|
||||
(context.shape[0], *self.latents.size())
|
||||
)
|
||||
|
||||
latent_attention_mask = torch.ones(
|
||||
(attention_mask.size(0), latents.size(1)),
|
||||
dtype=attention_mask.dtype,
|
||||
device=attention_mask.device,
|
||||
)
|
||||
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
|
||||
attention_mask = _prepare_4d_attention_mask(
|
||||
attention_mask, latents.dtype, tgt_len=self.n_latents
|
||||
)
|
||||
|
||||
compressed_context = latents
|
||||
for perceiver_layer in self.layers:
|
||||
compressed_context = perceiver_layer(
|
||||
compressed_context,
|
||||
context,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
compressed_context = self.norm(compressed_context)
|
||||
|
||||
return compressed_context
|
||||
|
||||
|
||||
class Idefics2Connector(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
self.modality_projection = Idefics2MLP(
|
||||
prefix=f"{prefix}.modality_projection", config=config, weights=weights
|
||||
)
|
||||
self.perceiver_resampler = Idefics2PerceiverResampler(
|
||||
prefix=f"{prefix}.perceiver_resampler", config=config, weights=weights
|
||||
)
|
||||
|
||||
def forward(self, image_hidden_states, attention_mask):
|
||||
image_hidden_states = self.modality_projection(image_hidden_states)
|
||||
image_hidden_states = self.perceiver_resampler(
|
||||
context=image_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
return image_hidden_states
|
||||
|
||||
|
||||
class Idefics2ForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
config.vision_config.use_medusa = config.use_medusa
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.use_medusa = config.use_medusa
|
||||
|
||||
vision_config = config.vision_config
|
||||
self.text_model = load_text_model(
|
||||
prefix="model" if not prefix else f"{prefix}.model",
|
||||
config=config.text_config,
|
||||
weights=weights,
|
||||
name="text_model",
|
||||
)
|
||||
self.dtype = weights.dtype
|
||||
self.vision_model = Idefics2VisionTransformer(
|
||||
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
||||
config=vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
self.connector = Idefics2Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
self.config = config
|
||||
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||
self.image_token_id = config.image_token_id
|
||||
self.pad_token_id = (
|
||||
config.pad_token_id if config.pad_token_id is not None else -1
|
||||
)
|
||||
|
||||
def _merge_input_ids_with_image_features(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
inputs_embeds: torch.Tensor,
|
||||
image_features: torch.Tensor,
|
||||
):
|
||||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
# mask = input_ids == self.config.image_token_index
|
||||
mask = input_ids == self.config.image_token_id
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
cu_seqlen_prefill: Optional[torch.Tensor],
|
||||
kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
block_tables: torch.Tensor,
|
||||
slots: torch.Tensor,
|
||||
input_lengths: torch.Tensor,
|
||||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
pixel_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
# Unused here
|
||||
image_sizes: Optional[torch.Tensor] = None,
|
||||
):
|
||||
inputs_embeds = self.text_model.embed_tokens(input_ids)
|
||||
if pixel_values is not None:
|
||||
batch_size, num_images, num_channels, height, width = pixel_values.shape
|
||||
all_states = []
|
||||
all_pixel_values = pixel_values
|
||||
all_pixel_mask = pixel_attention_mask
|
||||
for i in range(batch_size):
|
||||
pixel_values = all_pixel_values.to(
|
||||
dtype=self.dtype
|
||||
) # fp16 compatibility
|
||||
pixel_values = pixel_values[i : i + 1]
|
||||
pixel_values = pixel_values.view(num_images, *pixel_values.shape[2:])
|
||||
|
||||
# Remove padding images - padding images are full 0.
|
||||
nb_values_per_image = pixel_values.shape[1:].numel()
|
||||
real_images_inds = (pixel_values == 0.0).sum(
|
||||
dim=(-1, -2, -3)
|
||||
) != nb_values_per_image
|
||||
pixel_values = pixel_values[real_images_inds].contiguous()
|
||||
|
||||
# Handle the vision attention mask
|
||||
if pixel_attention_mask is None:
|
||||
pixel_attention_mask = torch.ones(
|
||||
size=(
|
||||
pixel_values.size(0),
|
||||
pixel_values.size(2),
|
||||
pixel_values.size(3),
|
||||
),
|
||||
dtype=torch.bool,
|
||||
device=pixel_values.device,
|
||||
)
|
||||
else:
|
||||
# Remove padding images from the mask/pP p
|
||||
pixel_attention_mask = all_pixel_mask[i : i + 1]
|
||||
pixel_attention_mask = pixel_attention_mask.view(
|
||||
1 * num_images, *pixel_attention_mask.shape[2:]
|
||||
)
|
||||
pixel_attention_mask = pixel_attention_mask[
|
||||
real_images_inds
|
||||
].contiguous()
|
||||
|
||||
patch_size = self.config.vision_config.patch_size
|
||||
patches_subgrid = pixel_attention_mask.unfold(
|
||||
dimension=1, size=patch_size, step=patch_size
|
||||
)
|
||||
patches_subgrid = patches_subgrid.unfold(
|
||||
dimension=2, size=patch_size, step=patch_size
|
||||
)
|
||||
patch_attention_mask = (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
|
||||
|
||||
# Get sequence from the vision encoder
|
||||
image_hidden_states = self.vision_model(
|
||||
pixel_values=pixel_values,
|
||||
patch_attention_mask=patch_attention_mask,
|
||||
)
|
||||
|
||||
# Modality projection & resampling
|
||||
image_hidden_states = self.connector(
|
||||
image_hidden_states,
|
||||
attention_mask=patch_attention_mask.view(pixel_values.size(0), -1),
|
||||
)
|
||||
all_states.append(image_hidden_states)
|
||||
image_hidden_states = torch.stack(all_states, dim=0)
|
||||
# When we generate, we don't want to replace the potential image_token_id that we generated by images
|
||||
# that simply don't exist
|
||||
inputs_embeds = self._merge_input_ids_with_image_features(
|
||||
input_ids, inputs_embeds, image_hidden_states
|
||||
)
|
||||
|
||||
hidden_states = self.text_model.model(
|
||||
inputs_embeds=inputs_embeds,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
true_max_s=max_s,
|
||||
prefill_cache_indices=None,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
logits, speculative_logits = self.text_model.lm_head(hidden_states)
|
||||
return logits, speculative_logits
|
|
@ -23,6 +23,10 @@ from torch import nn
|
|||
from transformers.activations import ACT2FN
|
||||
from transformers.image_processing_utils import select_best_resolution
|
||||
|
||||
from text_generation_server.models.custom_modeling.vlm import (
|
||||
load_text_model,
|
||||
load_vision_model,
|
||||
)
|
||||
from text_generation_server.utils.layers import (
|
||||
TensorParallelColumnLinear,
|
||||
TensorParallelRowLinear,
|
||||
|
@ -105,36 +109,6 @@ class LlavaNextMultiModalProjector(nn.Module):
|
|||
return hidden_states
|
||||
|
||||
|
||||
def load_vision_model(prefix, config, weights):
|
||||
if config.model_type == "clip_vision_model":
|
||||
from text_generation_server.models.custom_modeling.clip import (
|
||||
CLIPVisionTransformer,
|
||||
)
|
||||
|
||||
return CLIPVisionTransformer(
|
||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
def load_text_model(prefix, config, weights):
|
||||
if config.model_type == "llama":
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
|
||||
return FlashMistralForCausalLM(prefix, config, weights)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
class LlavaNextForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
|
@ -180,7 +154,12 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||
"""In place merges in vision_embeddings with inputs_embeds."""
|
||||
mask = input_ids == self.config.image_token_index
|
||||
# Let's pray we have enabled enough slots !
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
try:
|
||||
inputs_embeds[mask] = image_features.view(-1, image_features.shape[-1])
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot fill images right now. If error happens at warmup, make sure you have enough `--max-input-tokens` to handle images. If error happens at regular runtime, please fill in an issue: {e}"
|
||||
)
|
||||
return inputs_embeds
|
||||
|
||||
def forward(
|
||||
|
@ -196,6 +175,8 @@ class LlavaNextForConditionalGeneration(nn.Module):
|
|||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
pixel_values: torch.FloatTensor = None,
|
||||
# Unused for this model
|
||||
pixel_attention_mask=None,
|
||||
image_sizes: Optional[torch.LongTensor] = None,
|
||||
):
|
||||
inputs_embeds = self.language_model.embed_tokens(input_ids)
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
def load_text_model(prefix, config, weights, name=None):
|
||||
if config.model_type == "llama":
|
||||
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
|
||||
FlashLlamaForCausalLM,
|
||||
)
|
||||
|
||||
return FlashLlamaForCausalLM(prefix, config, weights)
|
||||
elif config.model_type == "mistral":
|
||||
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
|
||||
FlashMistralForCausalLM,
|
||||
)
|
||||
|
||||
return FlashMistralForCausalLM(prefix, config, weights, name=name)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
||||
|
||||
|
||||
def load_vision_model(prefix, config, weights):
|
||||
if config.model_type == "clip_vision_model":
|
||||
from text_generation_server.models.custom_modeling.clip import (
|
||||
CLIPVisionTransformer,
|
||||
)
|
||||
|
||||
return CLIPVisionTransformer(
|
||||
prefix=f"{prefix}.vision_model", config=config, weights=weights
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported model type {config.model_type}")
|
|
@ -511,18 +511,33 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
|
||||
if cu_seqlen_prefill is None:
|
||||
logits, speculative_logits = self.compiled_model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
else:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
cu_seqlen_prefill=cu_seqlen_prefill,
|
||||
kv_cache=kv_cache,
|
||||
block_tables=block_tables,
|
||||
slots=slots,
|
||||
input_lengths=input_lengths,
|
||||
max_s=max_s,
|
||||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
return logits, speculative_logits
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
import torch
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
)
|
||||
from text_generation_server.models.custom_modeling.idefics2 import (
|
||||
Idefics2ForConditionalGeneration,
|
||||
)
|
||||
|
||||
from text_generation_server.models.vlm_causal_lm import VlmCausalLM
|
||||
|
||||
|
||||
class Idefics2(VlmCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str,
|
||||
revision: Optional[str] = None,
|
||||
quantize: Optional[str] = None,
|
||||
use_medusa: Optional[str] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
trust_remote_code: bool = False,
|
||||
):
|
||||
self.processor = AutoProcessor.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
# XXX: Extremely important to cap resolution in order to limit
|
||||
# VRAM usage.
|
||||
size={"longest_edge": 448, "shortest_edge": 378},
|
||||
)
|
||||
super().__init__(
|
||||
model_cls=Idefics2ForConditionalGeneration,
|
||||
model_id=model_id,
|
||||
revision=revision,
|
||||
quantize=quantize,
|
||||
use_medusa=use_medusa,
|
||||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.text_model.model.layers),
|
||||
model.text_model.model.num_key_value_heads,
|
||||
model.text_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.text_model, "max_past", None)
|
|
@ -1,6 +1,6 @@
|
|||
import torch
|
||||
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
|
@ -34,3 +34,13 @@ class LlavaNext(VlmCausalLM):
|
|||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.language_model.model.layers),
|
||||
model.language_model.model.num_key_value_heads,
|
||||
model.language_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.language_model, "max_past", None)
|
||||
|
|
|
@ -64,6 +64,46 @@ def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
|
|||
return height // patch_size, width // patch_size
|
||||
|
||||
|
||||
def image_text_replacement(image_input, config, image_id) -> str:
|
||||
if config.model_type == "idefics2":
|
||||
# TODO technically depends on image splitting which is not implemented.
|
||||
num_features = 320
|
||||
return (
|
||||
"<fake_token_around_image>"
|
||||
+ "<image>" * num_features
|
||||
+ "<fake_token_around_image>"
|
||||
)
|
||||
elif config.model_type == "llava_next":
|
||||
height, width = image_input["image_sizes"][image_id]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
from loguru import logger
|
||||
|
||||
logger.info(f"Found {num_features} in image of resolution {height}x{width}")
|
||||
return "<image>" * num_features
|
||||
else:
|
||||
raise RuntimeError(f"Unknown config {config.model_type} for multimodal")
|
||||
|
||||
|
||||
def get_unpadded_features(
|
||||
height: int, width: int, npatches: int, num_patch_height: int, num_patch_width: int
|
||||
) -> Tuple[int, int]:
|
||||
current_height = npatches * num_patch_height
|
||||
current_width = npatches * num_patch_width
|
||||
|
||||
aspect_ratio: float = width / height
|
||||
current_aspect_ratio: float = current_width / current_height
|
||||
if aspect_ratio > current_aspect_ratio:
|
||||
new_height = (height * current_width) // width
|
||||
current_height = new_height
|
||||
else:
|
||||
new_width = (width * current_height) // height
|
||||
current_width = new_width
|
||||
|
||||
unpadded_features = current_height * current_width
|
||||
newline_features = current_height
|
||||
return (unpadded_features, newline_features)
|
||||
|
||||
|
||||
def get_number_of_features(height: int, width: int, config) -> int:
|
||||
# From config
|
||||
# Hardcoded for CLIP for now
|
||||
|
@ -81,12 +121,9 @@ def get_number_of_features(height: int, width: int, config) -> int:
|
|||
image_grid_pinpoints,
|
||||
image_size,
|
||||
)
|
||||
|
||||
height_of_patch = math.ceil(height / width * npatches)
|
||||
|
||||
unpadded_features = npatches * height_of_patch * num_patch_height * num_patch_width
|
||||
# They are only added after width
|
||||
newline_features = height_of_patch * num_patch_width
|
||||
unpadded_features, newline_features = get_unpadded_features(
|
||||
height, width, npatches, num_patch_height, num_patch_width
|
||||
)
|
||||
# The base patch covers the entire image
|
||||
base_features = npatches**2
|
||||
return unpadded_features + newline_features + base_features
|
||||
|
@ -99,12 +136,9 @@ def load_data_uri(image_uri: str) -> Image.Image:
|
|||
return image
|
||||
|
||||
|
||||
# assert get_number_of_features(889, 1024) == 2634, f"{get_number_of_features(889, 1024)}"
|
||||
# assert get_number_of_features(640, 640) == 2928
|
||||
|
||||
|
||||
class VlmCausalLMBatch(FlashMistralBatch):
|
||||
pixel_values: Optional[List[torch.Tensor]]
|
||||
pixel_attention_mask: Optional[List[torch.Tensor]]
|
||||
image_sizes: Optional[List[Tuple[int, int]]]
|
||||
|
||||
@classmethod
|
||||
|
@ -112,6 +146,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||
def concatenate(cls, batches):
|
||||
batch = super(VlmCausalLMBatch, cls).concatenate(batches)
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
|
@ -119,6 +154,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||
def filter(self, request_ids: List[int]):
|
||||
batch = super().filter(request_ids)
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
|
@ -130,6 +166,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||
for r in requests:
|
||||
chunks = split(r.inputs)
|
||||
full_text = ""
|
||||
image_id = 0
|
||||
for chunk in chunks:
|
||||
if chunk["type"] == "text":
|
||||
full_text += chunk["content"]
|
||||
|
@ -147,9 +184,7 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||
"Cannot process input image not starting with data:"
|
||||
)
|
||||
image_input = processor.image_processor(image, return_tensors="pt")
|
||||
height, width = image_input["image_sizes"][0]
|
||||
num_features = get_number_of_features(height, width, config)
|
||||
full_text += "<image>" * num_features
|
||||
full_text += image_text_replacement(image_input, config, image_id)
|
||||
image_inputs.append(image_input)
|
||||
else:
|
||||
raise RuntimeError(f"Invalid chunk type {chunk['type']}")
|
||||
|
@ -161,12 +196,21 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||
batch_inputs, truncation=True, max_length=max_truncation
|
||||
)["input_ids"]
|
||||
if image_inputs:
|
||||
image_inputs = {
|
||||
image_input = image_inputs[0]
|
||||
new_image_inputs = {
|
||||
"pixel_values": torch.cat(
|
||||
[img["pixel_values"] for img in image_inputs], dim=0
|
||||
),
|
||||
"image_sizes": torch.cat([img["image_sizes"] for img in image_inputs]),
|
||||
}
|
||||
if "pixel_attention_mask" in image_input:
|
||||
new_image_inputs["pixel_attention_mask"] = torch.cat(
|
||||
[img["pixel_attention_mask"] for img in image_inputs], dim=0
|
||||
)
|
||||
if "image_sizes" in image_input:
|
||||
new_image_inputs["image_sizes"] = torch.cat(
|
||||
[img["image_sizes"] for img in image_inputs], dim=0
|
||||
)
|
||||
image_inputs = new_image_inputs
|
||||
else:
|
||||
image_inputs = None
|
||||
return batch_tokenized_inputs, image_inputs
|
||||
|
@ -187,9 +231,19 @@ class VlmCausalLMBatch(FlashMistralBatch):
|
|||
batch = cls.from_tokenized(pb, tokenizer, batch_tokenized_inputs, dtype, device)
|
||||
if image_inputs is not None:
|
||||
batch.pixel_values = image_inputs["pixel_values"].to(device=device)
|
||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||
if "pixel_attention_mask" in image_inputs:
|
||||
batch.pixel_attention_mask = image_inputs["pixel_attention_mask"].to(
|
||||
device=device
|
||||
)
|
||||
else:
|
||||
batch.pixel_attention_mask = None
|
||||
if "image_sizes" in image_inputs:
|
||||
batch.image_sizes = image_inputs["image_sizes"].to(device=device)
|
||||
else:
|
||||
batch.image_sizes = None
|
||||
else:
|
||||
batch.pixel_values = None
|
||||
batch.pixel_attention_mask = None
|
||||
batch.image_sizes = None
|
||||
return batch
|
||||
|
||||
|
@ -199,16 +253,6 @@ class VlmCausalLM(BaseFlashMistral):
|
|||
def batch_type(self) -> Type[VlmCausalLMBatch]:
|
||||
return VlmCausalLMBatch
|
||||
|
||||
def get_layer_config(self, model) -> Tuple[int, int, int]:
|
||||
return (
|
||||
len(model.language_model.model.layers),
|
||||
model.language_model.model.num_key_value_heads,
|
||||
model.language_model.model.head_size,
|
||||
)
|
||||
|
||||
def max_past(self) -> Optional[int]:
|
||||
return getattr(self.model.language_model, "max_past", None)
|
||||
|
||||
def forward(
|
||||
self, batch: VlmCausalLMBatch
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
@ -270,17 +314,14 @@ class VlmCausalLM(BaseFlashMistral):
|
|||
max_s = min(self.max_past(), max_s)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
padded_bs = bs
|
||||
if bs == 3:
|
||||
padded_bs = 4
|
||||
elif 3 < bs <= 8:
|
||||
padded_bs = 8
|
||||
elif bs > 8:
|
||||
padded_bs = (bs + 7) // 8 * 8
|
||||
|
||||
# Try to find an associated cuda graph
|
||||
cuda_graph = self.cuda_graphs.get(padded_bs, None)
|
||||
|
||||
bs = input_ids.shape[0]
|
||||
sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
|
||||
if sorted_padded_bs:
|
||||
# Get associated cuda graph
|
||||
cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
|
||||
else:
|
||||
cuda_graph = None
|
||||
if cu_seqlen_prefill is not None or cuda_graph is None:
|
||||
logits, speculative_logits = self.model.forward(
|
||||
input_ids=input_ids,
|
||||
|
@ -294,12 +335,15 @@ class VlmCausalLM(BaseFlashMistral):
|
|||
prefill_cache_indices=batch.prefill_cache_indices,
|
||||
lm_head_indices=lm_head_indices,
|
||||
pixel_values=batch.pixel_values,
|
||||
pixel_attention_mask=batch.pixel_attention_mask,
|
||||
image_sizes=batch.image_sizes,
|
||||
)
|
||||
if batch.prefill_cache_indices is not None:
|
||||
batch.prefill_cache_indices = None
|
||||
if batch.pixel_values is not None:
|
||||
batch.pixel_values = None
|
||||
if batch.pixel_attention_mask is not None:
|
||||
batch.pixel_attention_mask = None
|
||||
if batch.image_sizes is not None:
|
||||
batch.image_sizes = None
|
||||
return logits, speculative_logits
|
||||
|
|
Loading…
Reference in New Issue