feat: format code (#1070)
parent b32e9ce9d5
commit 47954b81e9
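This commit only reformats code; no behavior changes. The rewrites below look like the output of an automatic formatter such as black (an assumption; the tool is not named in the commit): calls that exceed the line-length limit are wrapped to one argument per line with a trailing comma, already-wrapped calls that fit on one line are joined back, and long boolean or conditional expressions are parenthesized and split. A minimal, hypothetical sketch of the pattern follows; `describe` is an illustrative helper, not part of the repository.

    # Sketch only: assumes black-style formatting with an 88-character limit.
    def describe(model_id, revision=None, quantize=None, dtype=None, trust_remote_code=False):
        return (model_id, revision, quantize, dtype, trust_remote_code)

    # Before formatting: the call does not fit on a single line.
    before = describe("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", None, quantize="awq", dtype=None, trust_remote_code=False)

    # After formatting: one argument per line, with a trailing ("magic") comma.
    after = describe(
        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
        None,
        quantize="awq",
        dtype=None,
        trust_remote_code=False,
    )
    assert before == after  # formatting only; the call itself is unchanged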
@@ -137,7 +137,7 @@ class Client:
             typical_p=typical_p,
             watermark=watermark,
             decoder_input_details=decoder_input_details,
-            top_n_tokens=top_n_tokens
+            top_n_tokens=top_n_tokens,
         )
         request = Request(inputs=prompt, stream=False, parameters=parameters)

@@ -133,7 +133,9 @@ class Request(BaseModel):
            and parameters.best_of > 1
            and field_value
        ):
-            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
+            raise ValidationError(
+                "`best_of` != 1 is not supported when `stream` == True"
+            )
        return field_value

@@ -3,7 +3,11 @@ import pytest

 @pytest.fixture(scope="module")
 def flash_llama_awq_handle(launcher):
-    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=1, quantize="awq") as handle:
+    with launcher(
+        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
+        num_shard=1,
+        quantize="awq",
+    ) as handle:
         yield handle

@@ -12,6 +16,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
     await flash_llama_awq_handle.health(300)
     return flash_llama_awq_handle.client

+
 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
@@ -20,11 +25,13 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
     )

     assert response.details.generated_tokens == 10
-    assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
+    assert (
+        response.generated_text
+        == "\nWhat is the difference between Deep Learning and Machine"
+    )
     assert response == response_snapshot


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
@@ -49,16 +56,18 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):

 @pytest.mark.asyncio
 @pytest.mark.private
-async def test_flash_llama_awq_load(
-    flash_llama_awq, generate_load, response_snapshot
-):
+async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
     responses = await generate_load(
         flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
     )

     assert len(responses) == 4
-    assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
+    assert all(
+        [
+            r.generated_text
+            == "\nWhat is the difference between Deep Learning and Machine"
+            for r in responses
+        ]
+    )

     assert responses == response_snapshot

@@ -1,15 +1,22 @@
 import pytest


 @pytest.fixture(scope="module")
 def flash_llama_awq_handle_sharded(launcher):
-    with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
+    with launcher(
+        "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
+        num_shard=2,
+        quantize="awq",
+    ) as handle:
         yield handle


 @pytest.fixture(scope="module")
 async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
     await flash_llama_awq_handle_sharded.health(300)
     return flash_llama_awq_handle_sharded.client


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
@@ -18,9 +25,13 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
     )

     assert response.details.generated_tokens == 10
-    assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
+    assert (
+        response.generated_text
+        == "\nWhat is the difference between Deep Learning and Machine"
+    )
     assert response == response_snapshot


 @pytest.mark.asyncio
 @pytest.mark.private
 async def test_flash_llama_awq_load_sharded(
@@ -31,6 +42,12 @@ async def test_flash_llama_awq_load_sharded(
     )

     assert len(responses) == 4
-    assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
+    assert all(
+        [
+            r.generated_text
+            == "\nWhat is the difference between Deep Learning and Machine"
+            for r in responses
+        ]
+    )

     assert responses == response_snapshot
@@ -3,9 +3,7 @@ import pytest

 @pytest.fixture(scope="module")
 def idefics_handle(launcher):
-    with launcher(
-        "HuggingFaceM4/idefics-9b-instruct", num_shard=2
-    ) as handle:
+    with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle:
         yield handle

@@ -45,12 +45,15 @@ def test_stopping_criteria_max():
     assert criteria(1, "") == (False, None)
     assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)


 def test_batch_top_tokens():
     top_n_tokens = [0, 2, 3, 4, 5]
     top_n_tokens_tensor = torch.tensor(top_n_tokens)
-    inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
+    inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)

-    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
+    topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
+        top_n_tokens, top_n_tokens_tensor, inp_logprobs
+    )

     assert topn_tok_ids[0] == []
     assert topn_tok_ids[1] == [0, 3]
@@ -125,8 +125,12 @@ def download_weights(

    if not is_local_model:
        try:
-            adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
-            utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
+            adapter_config_filename = hf_hub_download(
+                model_id, revision=revision, filename="adapter_config.json"
+            )
+            utils.download_and_unload_peft(
+                model_id, revision, trust_remote_code=trust_remote_code
+            )
            is_local_model = True
            utils.weight_files(model_id, revision, extension)
            return
@@ -179,11 +183,12 @@ def download_weights(
        import transformers
        import json

        if is_local_model:
            config_filename = os.path.join(model_id, "config.json")
        else:
-            config_filename = hf_hub_download(model_id, revision=revision, filename="config.json")
+            config_filename = hf_hub_download(
+                model_id, revision=revision, filename="config.json"
+            )
        with open(config_filename, "r") as f:
            config = json.load(f)
        architecture = config["architectures"][0]
@@ -153,7 +153,11 @@ def get_model(
        )
    elif model_type == "mpt":
        return MPTSharded(
-            model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
+            model_id,
+            revision,
+            quantize=quantize,
+            dtype=dtype,
+            trust_remote_code=trust_remote_code,
        )

    elif model_type == "gpt_neox":
@@ -252,13 +256,13 @@ def get_model(
        )
    elif model_type == "idefics":
        if FLASH_ATTENTION:
            return IDEFICSSharded(
                model_id,
                revision,
                quantize=quantize,
                dtype=dtype,
                trust_remote_code=trust_remote_code,
            )
        else:
            raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))

@@ -269,13 +273,9 @@ def get_model(
            "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
        )
    if quantize == "awq":
-        raise ValueError(
-            "awq quantization is not supported for AutoModel"
-        )
+        raise ValueError("awq quantization is not supported for AutoModel")
    elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
-        raise ValueError(
-            "4bit quantization is not supported for AutoModel"
-        )
+        raise ValueError("4bit quantization is not supported for AutoModel")
    if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
        return CausalLM(
            model_id,
@@ -643,9 +643,12 @@ class CausalLM(Model):
                # Decode generated tokens
                output_text, _, _ = self.decode_token(
                    all_input_ids[:, 0],
-                    prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
-                    read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-                    skip_special_tokens=True
+                    prefix_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens
+                    - 1,
+                    read_offset=len(all_input_ids)
+                    - stopping_criteria.current_tokens,
+                    skip_special_tokens=True,
                )
                # Get seed
                if isinstance(next_token_chooser.choice, Sampling):
@@ -40,7 +40,10 @@ from text_generation_server.utils.layers import (
 )

 CUSTOM_KERNELS_ENABLED = False
-if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if (
+    torch.cuda.is_available()
+    and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
     try:
         from custom_kernels import fused_bloom_attention_cuda

@@ -169,6 +169,7 @@ def load_attention(config, prefix, weights):
         bias=False,
     )

+
 def _load_gqa(config, prefix: str, weights):
     assert config.hidden_size % config.num_attention_heads == 0
     assert config.num_attention_heads % weights.process_group.size() == 0
@@ -211,7 +212,10 @@ class FlashLlamaAttention(torch.nn.Module):
        # config=config, prefix=f"{prefix}.rotary_emb", weights=weights
        # )
        self.rotary_emb = PositionRotaryEmbedding.static(
-            config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
+            config=config,
+            dim=self.head_size,
+            base=config.rope_theta,
+            device=weights.device,
        )

        self.softmax_scale = self.head_size**-0.5
@@ -20,7 +20,12 @@ import numpy as np
 from PIL import Image

 from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
-from transformers.image_transforms import resize, to_channel_dimension_format, rescale, normalize
+from transformers.image_transforms import (
+    resize,
+    to_channel_dimension_format,
+    rescale,
+    normalize,
+)
 from transformers.image_utils import (
     ChannelDimension,
     ImageInput,
@@ -121,7 +126,11 @@ class IdeficsImageProcessor(BaseImageProcessor):
            a PyTorch tensor of the processed images
        """
        image_size = image_size if image_size is not None else self.image_size
-        image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
+        image_num_channels = (
+            image_num_channels
+            if image_num_channels is not None
+            else self.image_num_channels
+        )
        image_mean = image_mean if image_mean is not None else self.image_mean
        image_std = image_std if image_std is not None else self.image_std
        size = (image_size, image_size)
@@ -160,9 +169,13 @@ class IdeficsImageProcessor(BaseImageProcessor):
        images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
        images = [self.rescale(image=image, scale=1 / 255) for image in images]
        images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
-        images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
+        images = [
+            to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
+        ]
        # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
-        images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
+        images = BatchFeature(
+            data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
+        )["pixel_values"]

        return images

@@ -185,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
            response.raise_for_status()
            return Image.open(BytesIO(response.content))
        else:
-            raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
+            raise ValueError(
+                f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}"
+            )

    def rescale(
        self,
@@ -255,10 +270,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
            `np.ndarray`: The normalized image.
        """
        # TODO 4.32
-        return normalize(
-            image, mean=mean, std=std, data_format=data_format, **kwargs
-        )
+        return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)


 import transformers

 transformers.IdeficsImageProcessor = IdeficsImageProcessor
@@ -28,7 +28,11 @@ from torch.nn import CrossEntropyLoss

 from transformers import PreTrainedModel
 from transformers.activations import ACT2FN
-from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPast,
+    CausalLMOutputWithPast,
+    dataclass,
+)
 from transformers.modeling_utils import PretrainedConfig
 from transformers.utils import (
     add_start_docstrings,
@@ -37,8 +41,12 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
-from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer
-from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler
+from text_generation_server.models.custom_modeling.idefics_vision import (
+    IdeficsVisionTransformer,
+)
+from text_generation_server.models.custom_modeling.idefics_perceiver import (
+    IdeficsPerceiverResampler,
+)
 from text_generation_server.utils.layers import (
     TensorParallelColumnLinear,
     TensorParallelEmbedding,
@@ -49,10 +57,12 @@ from text_generation_server.utils.layers import (
 )
 import dropout_layer_norm

+
 @dataclass
 class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
     image_hidden_states: Optional[torch.FloatTensor] = None

+
 @dataclass
 class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
     image_hidden_states: Optional[torch.FloatTensor] = None
@@ -78,25 +88,39 @@ def expand_inputs_for_generation(
    **model_kwargs,
 ):
    expanded_return_idx = (
-        torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
+        torch.arange(input_ids.shape[0])
+        .view(-1, 1)
+        .repeat(1, expand_size)
+        .view(-1)
+        .to(input_ids.device)
    )
    input_ids = input_ids.index_select(0, expanded_return_idx)

    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
+        model_kwargs["token_type_ids"] = token_type_ids.index_select(
+            0, expanded_return_idx
+        )
    if attention_mask is not None:
-        model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
-        model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
+        model_kwargs["attention_mask"] = attention_mask.index_select(
+            0, expanded_return_idx
+        )
+        model_kwargs["image_attention_mask"] = model_kwargs[
+            "image_attention_mask"
+        ].index_select(0, expanded_return_idx)
+        model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(
            0, expanded_return_idx
        )
-        model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)

    if is_encoder_decoder:
        if encoder_outputs is None:
-            raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
-        encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
+            raise ValueError(
+                "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
+            )
+        encoder_outputs[
+            "last_hidden_state"
+        ] = encoder_outputs.last_hidden_state.index_select(
            0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
        )
        model_kwargs["encoder_outputs"] = encoder_outputs
@@ -120,14 +144,17 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
    # update token_type_ids with last value
    if "token_type_ids" in model_kwargs:
        token_type_ids = model_kwargs["token_type_ids"]
-        model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+        model_kwargs["token_type_ids"] = torch.cat(
+            [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+        )

    # update attention masks
    if not is_encoder_decoder:
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
-                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
+                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
+                dim=-1,
            )
        if "image_attention_mask" in model_kwargs:
            image_attention_mask = model_kwargs["image_attention_mask"]
@@ -180,8 +207,12 @@ def freeze_model(model, module_exceptions=[]):
    }
    module_exceptions_mapped = [mapping[m] for m in module_exceptions]
    for module in model.modules():
-        if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
-            module.requires_grad_(True)  # Explicitely setting it to true to avoid any mistakes
+        if module_exceptions and any(
+            [isinstance(module, t) for t in module_exceptions_mapped]
+        ):
+            module.requires_grad_(
+                True
+            )  # Explicitely setting it to true to avoid any mistakes
        else:
            module.requires_grad_(False)
    return model
@@ -195,15 +226,21 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
    ):
        super().__init__()
        self.num_embeddings = config.vocab_size
-        self.weight = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
-        self.additional_weight = nn.Parameter(weights.get_tensor(f"model.embed_tokens.additional_embedding.weight"))
+        self.weight = TensorParallelEmbedding(
+            prefix="model.embed_tokens", weights=weights
+        )
+        self.additional_weight = nn.Parameter(
+            weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
+        )

    def forward(self, input_ids):
        # Clone so that we don't modify the original input_ids later on
        input_ids = input_ids.clone()
        additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
        input_ids_additional_vocab = input_ids[additional_vocab_indices]
-        additional_embeddings = torch.nn.functional.embedding(input_ids_additional_vocab - self.num_embeddings, self.additional_weight)
+        additional_embeddings = torch.nn.functional.embedding(
+            input_ids_additional_vocab - self.num_embeddings, self.additional_weight
+        )

        # for successful lookup replace input_ids with 0, the results of these will be discarded anyway
        input_ids[additional_vocab_indices] = 0
@@ -234,7 +271,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
            config=config, prefix="lm_head", weights=weights
        )
        self.additional_fc = FastLinear.load(
-            config=config, prefix="lm_head.additional_fc", weights=weights, bias=False,
+            config=config,
+            prefix="lm_head.additional_fc",
+            weights=weights,
+            bias=False,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
@@ -257,7 +297,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):

 # Copied from transformers.models.bart.modeling_bart._make_causal_mask
 def _make_causal_mask(
-    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
+    input_ids_shape: torch.Size,
+    dtype: torch.dtype,
+    device: torch.device,
+    past_key_values_length: int = 0,
 ):
     """
     Make causal mask used for bi-directional self-attention.
@@ -269,8 +312,18 @@ def _make_causal_mask(
    mask = mask.to(dtype)

    if past_key_values_length > 0:
-        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
-    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
+        mask = torch.cat(
+            [
+                torch.zeros(
+                    tgt_len, past_key_values_length, dtype=dtype, device=device
+                ),
+                mask,
+            ],
+            dim=-1,
+        )
+    return mask[None, None, :, :].expand(
+        bsz, 1, tgt_len, tgt_len + past_key_values_length
+    )


 def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@@ -284,7 +337,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]

    inverted_mask = 1.0 - expanded_mask

-    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )


 class IdeficsRMSNorm(nn.Module):
@@ -346,7 +401,6 @@ class IdeficsRMSNorm(nn.Module):
        if unwrap:
            normed_hidden_states = normed_hidden_states.view(*shape)

-
        return normed_hidden_states


@@ -367,7 +421,10 @@ class IdeficsMLP(nn.Module):
            bias=False,
        )
        self.down_proj = TensorParallelRowLinear.load(
-            config, prefix=f"{prefix}.down_proj", weights=weights, bias=False,
+            config,
+            prefix=f"{prefix}.down_proj",
+            weights=weights,
+            bias=False,
        )
        self.act_fn = ACT2FN[config.hidden_act]

@@ -375,7 +432,9 @@ class IdeficsMLP(nn.Module):
        gate_up_states = self.gate_up_proj(hidden_states)
        shape = gate_up_states.shape
        gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
-        return self.down_proj(self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])
+        return self.down_proj(
+            self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
+        )


 # this was adapted from LlamaAttention
@@ -445,14 +504,22 @@ class IdeficsAttention(nn.Module):
        self.qk_layer_norms = qk_layer_norms
        if self.qk_layer_norms:
            self.q_layer_norm = IdeficsRMSNorm(
-                prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps
+                prefix=f"{prefix}.q_layer_norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
            )
            self.k_layer_norm = IdeficsRMSNorm(
-                prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps
+                prefix=f"{prefix}.q_layer_norm",
+                weights=weights,
+                eps=config.rms_norm_eps,
            )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+        return (
+            tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+            .transpose(1, 2)
+            .contiguous()
+        )

    def forward(
        self,
@@ -470,20 +537,42 @@ class IdeficsAttention(nn.Module):
        bsz, q_len, _ = hidden_states.size()

        if is_cross_attention:
-            query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+            query_states = self.q_proj(hidden_states).view(
+                bsz, q_len, self.num_heads, self.head_dim
+            )  # .transpose(1, 2)
            query_states = query_states.transpose(1, 2)
-            _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
-            key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+            (
+                _,
+                kv_len,
+                _,
+            ) = (
+                key_value_states.size()
+            )  # Note that, in this case, `kv_len` == `kv_seq_len`
+            key_states = (
+                self.k_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
+            )
            value_states = (
-                self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
+                self.v_proj(key_value_states)
+                .view(bsz, kv_len, self.num_heads, self.head_dim)
+                .transpose(1, 2)
            )
        else:
            qkv = self.qkv(hidden_states)
-            query_states, key_states, value_states = qkv.split(self.num_heads * self.head_dim, dim=2)
+            query_states, key_states, value_states = qkv.split(
+                self.num_heads * self.head_dim, dim=2
+            )

-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
+            query_states = query_states.view(
+                bsz, q_len, self.num_heads, self.head_dim
+            )  # .transpose(1, 2)
+            key_states = key_states.view(
+                bsz, q_len, self.num_heads, self.head_dim
+            )  # . transpose(1, 2)
+            value_states = value_states.view(
+                bsz, q_len, self.num_heads, self.head_dim
+            )  # .transpose(1, 2)
        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
@@ -493,10 +582,14 @@ class IdeficsAttention(nn.Module):
        )

        shape = query_states.shape
-        query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape)
+        query_states = self.rotary_emb(
+            query_states.view(-1, *shape[2:]), cos, sin
+        ).view(shape)

        shape = key_states.shape
-        key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape)
+        key_states = self.rotary_emb(
+            key_states.reshape(-1, *shape[2:]), cos, sin
+        ).view(shape)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
@@ -571,8 +664,14 @@ class IdeficsDecoderLayer(nn.Module):
            prefix=f"{prefix}.mlp",
            weights=weights,
        )
-        self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps)
+        self.input_layernorm = IdeficsRMSNorm(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = IdeficsRMSNorm(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
        self.dropout = config.dropout

    def forward(
@@ -583,7 +682,9 @@ class IdeficsDecoderLayer(nn.Module):
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -650,14 +751,22 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
            prefix=f"{prefix}.mlp",
            weights=weights,
        )
-        self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps)
+        self.input_layernorm = IdeficsRMSNorm(
+            prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
+        )
+        self.post_attention_layernorm = IdeficsRMSNorm(
+            prefix=f"{prefix}.post_attention_layernorm",
+            weights=weights,
+            eps=config.rms_norm_eps,
+        )
        self.config = config.dropout

        self.act_cross_attn = nn.Tanh()
        self.act_dense = nn.Tanh()

-        self.alpha_cross_attn = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_cross_attn"))
+        self.alpha_cross_attn = nn.Parameter(
+            weights.get_tensor(f"{prefix}.alpha_cross_attn")
+        )
        self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))

        if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
@@ -673,7 +782,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        no_images: Optional[bool] = False,
-    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+    ) -> Tuple[
+        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
+    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -695,7 +806,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
        )

        if past_key_value is not None:
-            raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.")
+            raise NotImplementedError(
+                "Past key value states are not implemented for Idefics cross attention module."
+            )

        residual = hidden_states

@@ -711,7 +824,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
        # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
        # when there are no images the model is used in pure language mode
        gate = 0 if no_images else 1
-        hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
+        hidden_states = (
+            residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
+        )

        # Fully Connected
        residual = hidden_states
@@ -896,11 +1011,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
        self.gated_cross_attn_layers = nn.ModuleList(
            [
                IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
-                for layer_id in range(num_cross_layers)]
+                for layer_id in range(num_cross_layers)
+            ]
        )
        # self.gradient_checkpointing = False

-        self.norm = IdeficsRMSNorm(prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps)
+        self.norm = IdeficsRMSNorm(
+            prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
+        )

        # self.gradient_checkpointing = False
        # Initialize weights and apply final processing
@@ -932,7 +1050,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
    #     self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
-    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
+    def _prepare_decoder_attention_mask(
+        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
+    ):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
@@ -946,11 +1066,13 @@ class IdeficsModel(IdeficsPreTrainedModel):

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
-                inputs_embeds.device
-            )
+            expanded_attn_mask = _expand_mask(
+                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+            ).to(inputs_embeds.device)
            combined_attention_mask = (
-                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
+                expanded_attn_mask
+                if combined_attention_mask is None
+                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask
@@ -974,23 +1096,35 @@ class IdeficsModel(IdeficsPreTrainedModel):
    ) -> Union[Tuple, BaseModelOutputWithPastImage]:
        device = input_ids.device if input_ids is not None else inputs_embeds.device

-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.use_return_dict
+        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
-            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
+            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
-            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
+            )

        seq_length_with_past = seq_length
        past_key_values_length = 0
@@ -1006,7 +1140,10 @@ class IdeficsModel(IdeficsPreTrainedModel):
        elif position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
-                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+                past_key_values_length,
+                seq_length + past_key_values_length,
+                dtype=torch.long,
+                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
@@ -1016,29 +1153,52 @@ class IdeficsModel(IdeficsPreTrainedModel):

        if image_hidden_states is None:
            if pixel_values is None and image_embeddings is None:
-                raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
+                raise ValueError(
+                    "Either pixel_values and image_embeddings have to be not-None."
+                )

            elif pixel_values is not None and image_embeddings is not None:
-                raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
+                raise ValueError(
+                    "You cannot specify both pixel_values and image_embeddings at the same time"
+                )

            elif pixel_values is not None:
                no_images = len(torch.nonzero(pixel_values)) == 0
-                pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
+                pixel_values = pixel_values.to(
+                    dtype=self.dtype, device=device
+                )  # fp16 compatibility
                batch_size, num_images = pixel_values.shape[:2]
-                pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
+                pixel_values = pixel_values.contiguous().view(
+                    batch_size * num_images, *pixel_values.shape[2:]
+                )

                # Get sequence from the vision encoder
-                image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
+                image_hidden_states = self.vision_model(
+                    pixel_values=pixel_values
+                ).last_hidden_state

            elif image_embeddings is not None:
-                batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
-                image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
-                image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
+                (
+                    batch_size,
+                    num_images,
+                    image_seq_len,
+                    image_hidden_size,
+                ) = image_embeddings.size()
+                image_hidden_states = image_embeddings.to(
+                    dtype=self.dtype, device=input_ids.device
+                )
+                image_hidden_states = image_hidden_states.view(
+                    batch_size * num_images, image_seq_len, image_hidden_size
+                )

            if self.config.use_resampler:
                image_hidden_states = self.perceiver_resampler(image_hidden_states)
-            image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
-            image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
+            image_seq_len, image_hidden_size = image_hidden_states.size(
+                1
+            ), image_hidden_states.size(2)
+            image_hidden_states = image_hidden_states.view(
+                batch_size, num_images * image_seq_len, image_hidden_size
+            )
        else:
            no_images = False
            num_images = pixel_values.shape[1]
@@ -1050,7 +1210,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
            text_seq_len = image_attention_mask.size(1)
            image_attention_mask = image_attention_mask.unsqueeze(-1)
            image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
-            image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
+            image_attention_mask = image_attention_mask.view(
+                batch_size, text_seq_len, num_images * image_seq_len
+            )
        image_batch_size, image_sequence_length, _ = image_hidden_states.size()
        image_hidden_shape = (image_batch_size, image_sequence_length)
        if image_attention_mask is None:
@@ -1060,7 +1222,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
        # if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
        #     raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")

-
        # if image_hidden_states is not None:
        # else:
        #     image_attention_mask = None
|
@ -1070,10 +1231,15 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||||
# embed positions
|
# embed positions
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = torch.ones(
|
attention_mask = torch.ones(
|
||||||
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
|
(batch_size, seq_length_with_past),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=inputs_embeds.device,
|
||||||
)
|
)
|
||||||
attention_mask = self._prepare_decoder_attention_mask(
|
attention_mask = self._prepare_decoder_attention_mask(
|
||||||
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
@ -1094,7 +1260,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
past_key_value = (
|
||||||
|
past_key_values[idx] if past_key_values is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
def vblock(
|
def vblock(
|
||||||
main_block,
|
main_block,
|
||||||
|
@@ -1194,7 +1362,11 @@ class IdeficsModel(IdeficsPreTrainedModel):

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
-            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
        return BaseModelOutputWithPastImage(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
|
@ -1230,7 +1402,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
pixel_values: Optional[torch.FloatTensor] = None,
|
pixel_values: Optional[torch.FloatTensor] = None,
|
||||||
image_embeddings: Optional[torch.FloatTensor] = None,
|
image_embeddings: Optional[torch.FloatTensor] = None,
|
||||||
image_hidden_states: Optional[torch.FloatTensor] = None,
|
image_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
image_attention_mask: Optional[torch.Tensor] = None,
|
image_attention_mask: Optional[torch.Tensor] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
|
@ -1264,11 +1436,19 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||||
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
|
||||||
```"""
|
```"""
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = (
|
||||||
output_hidden_states = (
|
output_attentions
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
)
|
)
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = self.model(
|
outputs = self.model(
|
||||||
|
@ -1298,7 +1478,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
|
||||||
past_key_values=outputs.past_key_values,
|
past_key_values=outputs.past_key_values,
|
||||||
hidden_states=outputs.hidden_states,
|
hidden_states=outputs.hidden_states,
|
||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
image_hidden_states=outputs.image_hidden_states
|
image_hidden_states=outputs.image_hidden_states,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||||
@@ -1316,12 +1496,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
return expand_inputs_for_generation(*args, **model_kwargs)

@staticmethod
-def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
-return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
+def _update_model_kwargs_for_generation(
+outputs, model_kwargs, is_encoder_decoder=False
+):
+return update_model_kwargs_for_generation(
+outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
+)

@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
-reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
+reordered_past += (
+tuple(
+past_state.index_select(0, beam_idx) for past_state in layer_past
+),
+)
return reordered_past
@@ -46,7 +46,8 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear,
)

-EPS=1e-5
+EPS = 1e-5


class IdeficsPerceiverResampler(nn.Module):
def __init__(
@@ -78,7 +79,12 @@ class IdeficsPerceiverResampler(nn.Module):

"""
super().__init__()
-self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
+self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
+embed_dim,
+n_heads,
+head_dim,
+n_latents,
+)
self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver

# Create Latents for Perceiver
@@ -107,14 +113,16 @@ class IdeficsPerceiverResampler(nn.Module):
prefix=f"{prefix}.blocks.{layer_id}.1",
intermediate_size=self.intermediate_dim,
config=config,
-weights=weights
+weights=weights,
),
]
)
for layer_id in range(depth)
]
)
-self.layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS)
+self.layer_norm = nn.LayerNorm.load(
+prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
+)

def forward(self, context: torch.Tensor) -> torch.Tensor:
"""Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
@@ -130,25 +138,34 @@ class IdeficsPerceiverResampler(nn.Module):


class IdeficsPerceiverAttention(nn.Module):
-def __init__(self,
-prefix,
-config,
-embed_dim: int,
-n_heads: int,
-head_dim: int,
-qk_layer_norms: bool,
-weights
-) -> None:
+def __init__(
+self,
+prefix,
+config,
+embed_dim: int,
+n_heads: int,
+head_dim: int,
+qk_layer_norms: bool,
+weights,
+) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__()
self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
self.qk_layer_norms = qk_layer_norms
# Normalization & Scaling
-self.context_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS)
-self.latents_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS)
+self.context_layer_norm = nn.LayerNorm.load(
+prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
+)
+self.latents_layer_norm = nn.LayerNorm.load(
+prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
+)
if self.qk_layer_norms:
-self.q_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS)
-self.k_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS)
+self.q_layer_norm = nn.LayerNorm.load(
+prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
+)
+self.k_layer_norm = nn.LayerNorm.load(
+prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
+)

self.qk_scale = self.head_dim**-0.5

@@ -164,10 +181,10 @@ class IdeficsPerceiverAttention(nn.Module):
self.q_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
)
self.k_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
)
self.v_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
)

@@ -202,7 +219,12 @@ class IdeficsPerceiverAttention(nn.Module):
# Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
# =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
# einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
-q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)]
+q, k, v = [
+x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
+1, 2
+)
+for x in (q, k, v)
+]

if self.qk_layer_norms:
q = self.q_layer_norm(q)
@@ -219,25 +241,34 @@ class IdeficsPerceiverAttention(nn.Module):


class IdeficsMLP(nn.Module):
-def __init__(self,
-prefix,
-intermediate_size,
-config,
-weights,
-):
+def __init__(
+self,
+prefix,
+intermediate_size,
+config,
+weights,
+):
"""Simple MLP block with intermediate_size and embedding size"""
super().__init__()
self.embed_dim = config.vision_config.embed_dim
self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
self.fc = TensorParallelColumnLinear.load(
-config=config, prefix=f"{prefix}.fc", weights=weights, bias=False,
+config=config,
+prefix=f"{prefix}.fc",
+weights=weights,
+bias=False,
)
self.act = nn.ReLU()
self.c_proj = TensorParallelRowLinear.load(
-config=config, prefix=f"{prefix}.c_proj", weights=weights, bias=False,
+config=config,
+prefix=f"{prefix}.c_proj",
+weights=weights,
+bias=False,
)

-def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
+def forward(
+self, hidden_states: Optional[Tuple[torch.FloatTensor]]
+) -> torch.FloatTensor:
hidden_states = self.ln(hidden_states)
hidden_states = self.fc(hidden_states)
hidden_states = self.act(hidden_states)
@@ -21,9 +21,16 @@ from urllib.parse import urlparse

from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
-from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
+from transformers.tokenization_utils_base import (
+BatchEncoding,
+PaddingStrategy,
+TextInput,
+TruncationStrategy,
+)
from transformers.utils import TensorType, is_torch_available
-from text_generation_server.models.custom_modeling.idefics_image_processing import IdeficsImageProcessor
+from text_generation_server.models.custom_modeling.idefics_image_processing import (
+IdeficsImageProcessor,
+)


if is_torch_available():
@@ -124,7 +131,14 @@ class IdeficsProcessor(ProcessorMixin):
image_processor_class = "IdeficsImageProcessor"
tokenizer_class = "LlamaTokenizerFast"

-def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs):
+def __init__(
+self,
+image_processor,
+tokenizer=None,
+image_size=224,
+add_end_of_utterance_token=None,
+**kwargs,
+):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
@@ -142,7 +156,8 @@ class IdeficsProcessor(ProcessorMixin):

self.tokenizer_was_trained_with_end_of_utterance_token = (
True
-if "<end_of_utterance>" in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
+if "<end_of_utterance>"
+in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
else False
)

@@ -265,7 +280,9 @@ class IdeficsProcessor(ProcessorMixin):

# if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
if add_end_of_utterance_token is None:
-add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
+add_end_of_utterance_token = (
+self.tokenizer_was_trained_with_end_of_utterance_token
+)

# turn non-batched prompts into batched
if not any(isinstance(i, list) for i in prompts):
@@ -358,10 +375,14 @@ class IdeficsProcessor(ProcessorMixin):
current_images = images[:local_max_num_images]

if len(current_images) > 0:
-padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
+padded_image_tensor = torch.zeros(
+max_num_images, *current_images.size()[1:]
+)
padded_image_tensor[: current_images.size(0)] = current_images
else:
-padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
+padded_image_tensor = torch.zeros(
+max_num_images, *self.default_image_dims
+)

output_images.append(padded_image_tensor)
output_input_ids.append(torch.tensor(padded_input_ids))
@@ -373,14 +394,19 @@ class IdeficsProcessor(ProcessorMixin):
output_attention_masks = torch.stack(output_attention_masks)

if at_least_one_image:
-image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
+image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
+output_input_ids, self.tokenizer
+)
image_attention_mask = incremental_to_binary_attention_mask(
image_attention_mask, num_classes=max_num_images
)
else:
# in full language mode we set the image mask to all-0s
image_attention_mask = torch.zeros(
-output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
+output_input_ids.shape[0],
+output_input_ids.shape[1],
+1,
+dtype=torch.bool,
)

return BatchFeature(
@@ -75,7 +75,9 @@ class IdeficsVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size

-self.class_embedding = nn.Parameter(weights.get_tensor(f"{prefix}.class_embedding"))
+self.class_embedding = nn.Parameter(
+weights.get_tensor(f"{prefix}.class_embedding")
+)

self.patch_embedding = nn.Conv2d.load_no_bias(
prefix=f"{prefix}.patch_embedding",
@@ -91,12 +93,16 @@ class IdeficsVisionEmbeddings(nn.Module):
self.position_embedding = TensorParallelEmbedding(
prefix="model.vision_model.embeddings.position_embedding", weights=weights
)
-self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
+self.position_ids = (
+torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
+)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
-patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
+patch_embeds = self.patch_embedding(
+pixel_values.to(dtype=target_dtype)
+) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
@@ -132,7 +138,6 @@ class IdeficsVisionAttention(nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()

self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
)
@@ -147,7 +152,11 @@ class IdeficsVisionAttention(nn.Module):
)

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+return (
+tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+.transpose(1, 2)
+.contiguous()
+)

def forward(
self,
@@ -186,7 +195,10 @@ class IdeficsVisionAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
-attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
+attn_weights = (
+attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
++ causal_attention_mask
+)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

if attention_mask is not None:
@@ -194,7 +206,10 @@ class IdeficsVisionAttention(nn.Module):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
-attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
+attn_weights = (
+attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
++ attention_mask
+)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -204,12 +219,18 @@ class IdeficsVisionAttention(nn.Module):
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
-attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+attn_weights_reshaped = attn_weights.view(
+bsz, self.num_heads, tgt_len, src_len
+)
+attn_weights = attn_weights_reshaped.view(
+bsz * self.num_heads, tgt_len, src_len
+)
else:
attn_weights_reshaped = None

-attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+attn_probs = nn.functional.dropout(
+attn_weights, p=self.dropout, training=self.training
+)

attn_output = torch.bmm(attn_probs, value_states)

@@ -253,11 +274,15 @@ class IdeficsVisionEncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
-self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
+self.self_attn = IdeficsVisionAttention(
+prefix=f"{prefix}.self_attn", config=config, weights=weights
+)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
)
-self.mlp = IdeficsVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+self.mlp = IdeficsVisionMLP(
+prefix=f"{prefix}.mlp", config=config, weights=weights
+)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
)
@@ -318,7 +343,11 @@ class IdeficsVisionEncoder(nn.Module):
self.config = config
self.layers = nn.ModuleList(
[
-IdeficsVisionEncoderLayer(prefix=f"{prefix}.encoder.layers.{layer_id}", config=config, weights=weights)
+IdeficsVisionEncoderLayer(
+prefix=f"{prefix}.encoder.layers.{layer_id}",
+config=config,
+weights=weights,
+)
for layer_id in range(config.num_hidden_layers)
]
)
@@ -362,11 +391,19 @@ class IdeficsVisionEncoder(nn.Module):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
-output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-output_hidden_states = (
-output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+output_attentions = (
+output_attentions
+if output_attentions is not None
+else self.config.output_attentions
+)
+output_hidden_states = (
+output_hidden_states
+if output_hidden_states is not None
+else self.config.output_hidden_states
+)
+return_dict = (
+return_dict if return_dict is not None else self.config.use_return_dict
)
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -406,9 +443,15 @@ class IdeficsVisionEncoder(nn.Module):
encoder_states = encoder_states + (hidden_states,)

if not return_dict:
-return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+return tuple(
+v
+for v in [hidden_states, encoder_states, all_attentions]
+if v is not None
+)
return BaseModelOutput(
-last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+last_hidden_state=hidden_states,
+hidden_states=encoder_states,
+attentions=all_attentions,
)

@@ -419,13 +462,19 @@ class IdeficsVisionTransformer(nn.Module):
self.config = config
embed_dim = config.hidden_size

-self.embeddings = IdeficsVisionEmbeddings(prefix=f"{prefix}.embeddings", config=config, weights=weights)
+self.embeddings = IdeficsVisionEmbeddings(
+prefix=f"{prefix}.embeddings", config=config, weights=weights
+)
self.pre_layrnorm = nn.LayerNorm.load(
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
)
-self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights)
+self.encoder = IdeficsVisionEncoder(
+prefix=prefix, config=config, weights=weights
+)
self.post_layernorm = nn.LayerNorm.load(
-prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps
+prefix=f"{prefix}.post_layernorm",
+weights=weights,
+eps=config.layer_norm_eps,
)

# copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
@@ -440,11 +489,19 @@ class IdeficsVisionTransformer(nn.Module):
Returns:

"""
-output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-output_hidden_states = (
-output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+output_attentions = (
+output_attentions
+if output_attentions is not None
+else self.config.output_attentions
+)
+output_hidden_states = (
+output_hidden_states
+if output_hidden_states is not None
+else self.config.output_hidden_states
+)
+return_dict = (
+return_dict if return_dict is not None else self.config.use_return_dict
)
-return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@@ -49,7 +49,10 @@ from text_generation_server.utils.layers import (


CUSTOM_KERNELS_ENABLED = False
-if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
+if (
+torch.cuda.is_available()
+and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
+):
try:
from custom_kernels import fused_attention_cuda

@@ -1005,9 +1005,12 @@ class FlashCausalLM(Model):
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
-prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
-read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-skip_special_tokens=True
+prefix_offset=len(all_input_ids)
+- stopping_criteria.current_tokens
+- 1,
+read_offset=len(all_input_ids)
+- stopping_criteria.current_tokens,
+skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,
@@ -8,7 +8,13 @@ import re

from dataclasses import dataclass
from opentelemetry import trace
-from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin
+from transformers import (
+AutoProcessor,
+AutoTokenizer,
+AutoModelForCausalLM,
+PreTrainedTokenizerBase,
+ProcessorMixin,
+)
from typing import Optional, Tuple, List, Type, Dict

from text_generation_server.models import Model
@@ -23,7 +29,8 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam

import re

-IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
+IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")


def split(string):
parts = []
@@ -41,6 +48,7 @@ def split(string):

return parts


tracer = trace.get_tracer(__name__)


@@ -94,7 +102,7 @@ class IdeficsCausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor: ProcessorMixin, # Hack
dtype: torch.dtype,
device: torch.device,
) -> "IdeficsCausalLMBatch":
@@ -137,12 +145,16 @@ class IdeficsCausalLMBatch(Batch):
padding=True,
truncation=True,
max_length=max_truncation,
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
-prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors
-read_offsets.append(input_len) # To decode without potential fallbacks errors
+prefix_offsets.append(
+input_len - 5
+) # To decode without potential fallbacks errors
+read_offsets.append(
+input_len
+) # To decode without potential fallbacks errors

input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
@@ -158,14 +170,21 @@ class IdeficsCausalLMBatch(Batch):
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
# Do the same for image_attention_mask
image_attention_mask = input_ids.new_zeros(
-(pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
+(
+pb.size,
+max_input_length + padding_right_offset,
+tokenized_inputs["pixel_values"].size(1),
+)
)
-image_attention_mask[:, :max_input_length, :] = tokenized_inputs["image_attention_mask"]
+image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
+"image_attention_mask"
+]

position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
-all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
+all_input_ids = tokenized_inputs["input_ids"].T.split(
+1, dim=1
+) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list

max_tokens = len(inputs) * (max_input_length + max_decode_tokens)

@@ -259,7 +278,7 @@ class IdeficsCausalLMBatch(Batch):
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
-:
+:,
]
if self.image_hidden_states is None:
image_hidden_states = None
@@ -308,7 +327,9 @@ class IdeficsCausalLMBatch(Batch):

@classmethod
@tracer.start_as_current_span("concatenate")
-def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
+def concatenate(
+cls, batches: List["IdeficsCausalLMBatch"]
+) -> "IdeficsCausalLMBatch":
# It adds new requests to the batch
# Used for padding
total_batch_size = 0
@@ -383,12 +404,20 @@ class IdeficsCausalLMBatch(Batch):

curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
-pixel_values = batch.pixel_values.new_zeros((total_batch_size, max_num_images, 3, 224, 224))
-pixel_values[start_index:end_index, :curr_batch_max_num_images] = batch.pixel_values
+pixel_values = batch.pixel_values.new_zeros(
+(total_batch_size, max_num_images, 3, 224, 224)
+)
+pixel_values[
+start_index:end_index, :curr_batch_max_num_images
+] = batch.pixel_values

if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
-(total_batch_size, max_input_length + padding_right_offset, max_num_images)
+(
+total_batch_size,
+max_input_length + padding_right_offset,
+max_num_images,
+)
)

# We need to slice the attention mask to remove padding from previous steps
@@ -409,11 +438,9 @@ class IdeficsCausalLMBatch(Batch):
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
-:curr_batch_max_num_images
+:curr_batch_max_num_images,
] = batch.image_attention_mask[
-:,
-batch_left_offset : - batch.padding_right_offset,
-:
+:, batch_left_offset : -batch.padding_right_offset, :
]

# Create empty tensor
@@ -550,7 +577,9 @@ class IdeficsCausalLM(Model):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
-from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
+from text_generation_server.models.custom_modeling.idefics_modeling import (
+IdeficsForVisionText2Text,
+)

if torch.cuda.is_available():
device = torch.device("cuda")
@@ -650,9 +679,13 @@ class IdeficsCausalLM(Model):
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
# token need to attend to the encoder hidden states (i.e. the vision encoder)
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
-image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1)
+image_attention_mask = batch.image_attention_mask[
+:, -(batch.padding_right_offset + 1)
+].unsqueeze(1)
else:
-image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
+image_attention_mask = batch.image_attention_mask[
+:, : -batch.padding_right_offset
+]

logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
@@ -725,9 +758,12 @@ class IdeficsCausalLM(Model):
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
-prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
-read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
-skip_special_tokens=True
+prefix_offset=len(all_input_ids)
+- stopping_criteria.current_tokens
+- 1,
+read_offset=len(all_input_ids)
+- stopping_criteria.current_tokens,
+skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@@ -761,7 +797,7 @@ class IdeficsCausalLM(Model):
else:
prefill_tokens = None

-top_tokens=None
+top_tokens = None

generation = Generation(
request.id,
@@ -771,7 +807,7 @@ class IdeficsCausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
-top_tokens
+top_tokens,
)

generations.append(generation)
@@ -793,7 +829,9 @@ class IdeficsCausalLM(Model):

# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
-batch.image_attention_mask[:, -batch.padding_right_offset, :] = batch.image_attention_mask[:, -(batch.padding_right_offset+1), :]
+batch.image_attention_mask[
+:, -batch.padding_right_offset, :
+] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
# Decrease right offset
batch.padding_right_offset -= 1

@@ -71,7 +71,8 @@ class Model(ABC):
# The prefix text is necessary only to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
prefix_text = self.tokenizer.decode(
-all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
+all_input_ids[prefix_offset:read_offset],
+skip_special_tokens=skip_special_tokens,
)
new_text = self.tokenizer.decode(
all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens
@@ -712,9 +712,11 @@ class Seq2SeqLM(Model):
# Decode all tokens
output_text, _, _ = self.decode_token(
all_decoder_input_ids,
-prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
+prefix_offset=len(all_decoder_input_ids)
+- decoder_input_length
+- 1,
read_offset=len(all_decoder_input_ids) - decoder_input_length,
-skip_special_tokens=True
+skip_special_tokens=True,
)

# Get seed
@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
@@ -26,7 +27,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# Force inference mode for the lifetime of TextGenerationService
self._inference_mode_raii_guard = torch._C._InferenceMode(True)


async def Info(self, request, context):
return self.model.info

@@ -55,9 +55,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())

async def Warmup(self, request, context):
-if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
+if (
+self.model.batch_type == IdeficsCausalLMBatch
+): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb(
-request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+request.batch,
+self.model.tokenizer,
+self.model.processor,
+self.model.dtype,
+self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
@@ -70,9 +76,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)

async def Prefill(self, request, context):
-if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
+if (
+self.model.batch_type == IdeficsCausalLMBatch
+): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb(
-request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
+request.batch,
+self.model.tokenizer,
+self.model.processor,
+self.model.dtype,
+self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
@@ -11,7 +11,7 @@ import awq_inference_engine # with CUDA kernels
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)

@@ -19,10 +19,10 @@ import awq_inference_engine # with CUDA kernels
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
super().__init__()

if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")

self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit

@@ -42,7 +42,9 @@ class WQLinear(nn.Module):

@torch.no_grad()
def forward(self, x):
-out_shape = x.shape[:-1] + (self.out_features, )
-out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+out_shape = x.shape[:-1] + (self.out_features,)
+out = awq_inference_engine.gemm_forward_cuda(
+x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
@@ -578,7 +578,9 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
return trainloader, valenc


-def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
+def get_loaders(
+name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
+):
if "wikitext2" in name:
return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
if "ptb" in name:
@@ -927,7 +929,7 @@ def quantize(
seed=seed,
model_id=model_id,
seqlen=model.seqlen,
-trust_remote_code=trust_remote_code
+trust_remote_code=trust_remote_code,
)

tick = time.time()
@@ -22,7 +22,7 @@ from text_generation_server.utils.gptq.quant_linear import QuantLinear


HAS_AWQ = True
try:
from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
HAS_AWQ = False
@@ -36,17 +36,19 @@ CAN_EXLLAMA = major >= 8
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
-HAS_EXLLAMA = True
-except ImportError:
-pass
+
+HAS_EXLLAMA = True
+except ImportError:
+pass

from typing import Optional

HAS_EETQ = False
try:
from EETQ import quant_weights, w8_a16_gemm

HAS_EETQ = True
except ImportError:
pass
@@ -74,12 +76,18 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
ln.bias = None
return ln


@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
-conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+conv2d = cls(
+in_channels=in_channels,
+out_channels=out_channels,
+kernel_size=kernel_size,
+stride=stride,
+)

conv2d.weight = nn.Parameter(weight)
conv2d.bias = nn.Parameter(bias)
@@ -87,10 +95,17 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st


@classmethod
-def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+def load_conv2d_no_bias(
+cls, prefix, weights, in_channels, out_channels, kernel_size, stride
+):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
-conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
+conv2d = cls(
+in_channels=in_channels,
+out_channels=out_channels,
+kernel_size=kernel_size,
+stride=stride,
+)

conv2d.weight = nn.Parameter(weight)
conv2d.bias = None
@@ -215,7 +230,10 @@ class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
-weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
+weight.data,
+requires_grad=False,
+compress_statistics=True,
+quant_type=quant_type,
)
self.compute_dtype = None
self.weight.cuda(weight.device)
@@ -246,7 +264,10 @@ class Linear4bit(nn.Module):

@lru_cache(1)
def warn_deprecate_bnb():
-logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce")
+logger.warning(
+"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
+)


def get_linear(weight, bias, quantize):
if quantize is None:
@@ -255,7 +276,9 @@ def get_linear(weight, bias, quantize):
if HAS_EETQ:
linear = EETQLinear(weight, bias)
else:
-raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
+raise ImportError(
+"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
+)
elif quantize == "bitsandbytes":
warn_deprecate_bnb()
linear = Linear8bitLt(
|
@@ -305,7 +328,14 @@ def get_linear(weight, bias, quantize):
             raise NotImplementedError(
                 f"The passed weight is not `awq` compatible, loader needs to be updated."
             )
-        linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
+        linear = WQLinear(
+            w_bit=bits,
+            group_size=groupsize,
+            qweight=qweight,
+            qzeros=qzeros,
+            scales=scales,
+            bias=bias is not None,
+        )
     else:
         raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
     return linear
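A short aside on the `get_linear` hunks (not part of the commit): the function is a dispatcher from the `quantize` flag to a concrete linear implementation, raising when a backend is missing or unknown. A rough, hypothetical summary of that mapping, limited to the branches visible in the hunks:

```python
def describe_linear(quantize):
    """Hypothetical summary of the dispatch performed by get_linear(weight, bias, quantize)."""
    table = {
        None: "FastLinear (plain fp16/bf16 matmul)",
        "eetq": "EETQLinear (int8 weight-only, requires the EETQ package)",
        "bitsandbytes": "Linear8bitLt (deprecated in favour of eetq)",
        "awq": "WQLinear (4-bit packed qweight/qzeros/scales)",
    }
    try:
        return table[quantize]
    except KeyError:
        raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")


print(describe_linear(None))
print(describe_linear("awq"))
```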
@@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer):
     @classmethod
     def load_qkv(cls, config, prefix: str, weights, bias: bool):
         """Specific method when the QKV was joined after the fact"""
-        weight = weights.get_weights_col_packed_qkv(
-            prefix, quantize=config.quantize
-        )
+        weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
         if bias:
             raise NotImplementedError("packed_qkv only implemented for baichuan")
         else:
@@ -530,14 +558,16 @@ try:
 
     def _create_inv_freq(dim, base, device):
         inv_freq = 1.0 / (
-            base
-            ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
+            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
         )
         return inv_freq
 
     def _get_rope_config(config):
         if os.getenv("ROPE_SCALING", None) is not None:
-            rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
+            rope_scaling = {
+                "type": os.environ["ROPE_SCALING"],
+                "factor": float(os.environ["ROPE_FACTOR"]),
+            }
             return rope_scaling
         return getattr(config, "rope_scaling", None)
 
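An aside on `_create_inv_freq` (not part of the commit): it produces the standard rotary-embedding frequencies, 1 / base^(2i / dim) for i = 0 .. dim/2 - 1; the rotation angles for a position t are then t times these frequencies. A small runnable check of that identity:

```python
import torch


def create_inv_freq(dim, base, device="cpu"):
    # Same formula as _create_inv_freq in the hunk above.
    return 1.0 / (
        base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
    )


inv_freq = create_inv_freq(dim=8, base=10000.0)
print(inv_freq.shape)      # torch.Size([4]): dim // 2 frequencies
print(inv_freq[0].item())  # 1.0: the first frequency is base ** 0
t = torch.arange(4, dtype=torch.float32)
print(torch.outer(t, inv_freq).shape)  # torch.Size([4, 4]) rotation angles
```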
@@ -563,9 +593,17 @@ try:
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
-                    return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
+                    return DynamicPositionRotaryEmbedding(
+                        dim=dim,
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=base,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
                 else:
-                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
             return cls(inv_freq, scaling_factor)
 
         @classmethod
@@ -583,9 +621,17 @@ try:
                 if rope_scaling["type"] == "linear":
                     pass
                 elif rope_scaling["type"] == "dynamic":
-                    return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
+                    return DynamicPositionRotaryEmbedding(
+                        dim=2 * inv_freq.shape[0],
+                        max_position_embeddings=config.max_position_embeddings,
+                        base=10000.0,
+                        device=inv_freq.device,
+                        scaling_factor=scaling_factor,
+                    )
                 else:
-                    raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
+                    raise NotImplementedError(
+                        f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
+                    )
             return cls(inv_freq, scaling_factor)
 
         def _update_cos_sin_cache(self, dtype, device, seqlen):
@@ -645,8 +691,13 @@ try:
                 or self._cos_cached.dtype != dtype
             ):
                 if seqlen > self.max_position_embeddings:
-                    newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
-                    self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
+                    newbase = self.base * (
+                        (self.scaling_factor * seqlen / self.max_position_embeddings)
+                        - (self.scaling_factor - 1)
+                    ) ** (self.dim / (self.dim - 2))
+                    self.inv_freq = _create_inv_freq(
+                        self.dim, newbase, self.inv_freq.device
+                    )
                 self._seq_len_cached = seqlen
                 t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
                 # Don't do einsum, it converts fp32 to fp16
@@ -656,6 +707,5 @@ try:
                 self._cos_cached = torch.cos(freqs).to(dtype)
                 self._sin_cached = torch.sin(freqs).to(dtype)
 
-
 except ImportError:
     pass
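An aside on the `newbase` hunk (not part of the commit): when the requested sequence length exceeds `max_position_embeddings`, the dynamic branch enlarges the RoPE base by the factor shown and rebuilds `inv_freq`, which stretches the usable context. A standalone sketch of the same arithmetic with illustrative numbers:

```python
def dynamic_ntk_base(base, dim, seqlen, max_position_embeddings, scaling_factor):
    # Same expression as the reformatted hunk above.
    return base * (
        (scaling_factor * seqlen / max_position_embeddings) - (scaling_factor - 1)
    ) ** (dim / (dim - 2))


# Illustrative values only: a model trained for 4096 positions asked for 8192.
newbase = dynamic_ntk_base(
    base=10000.0,
    dim=128,
    seqlen=8192,
    max_position_embeddings=4096,
    scaling_factor=1.0,
)
print(round(newbase))  # ~20221: a larger base lowers the frequencies
```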
@@ -6,6 +6,7 @@ import torch
 from transformers import AutoTokenizer
 from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
 
+
 def download_and_unload_peft(model_id, revision, trust_remote_code):
     torch_dtype = torch.float16
 
@@ -33,7 +34,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     base_model_id = model.peft_config["default"].base_model_name_or_path
 
     model = model.merge_and_unload()
 
     os.makedirs(model_id, exist_ok=True)
     cache_dir = model_id
     logger.info(f"Saving the newly created merged model to {cache_dir}")
@@ -41,6 +42,3 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
     model.save_pretrained(cache_dir, safe_serialization=True)
     model.config.save_pretrained(cache_dir)
     tokenizer.save_pretrained(cache_dir)
-
-
-
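An aside on the peft hunks (not part of the commit): `download_and_unload_peft` loads an adapter checkpoint, folds the adapter weights into the base model, and saves a plain model next to the tokenizer so the server can load it normally. A hedged usage sketch, with `my-org/my-lora-adapter` as a made-up adapter id:

```python
import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM

adapter_id = "my-org/my-lora-adapter"  # hypothetical adapter repository
out_dir = "./merged-model"

model = AutoPeftModelForCausalLM.from_pretrained(adapter_id, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(adapter_id)

model = model.merge_and_unload()  # fold the LoRA deltas into the base weights
model.save_pretrained(out_dir, safe_serialization=True)
model.config.save_pretrained(out_dir)
tokenizer.save_pretrained(out_dir)
```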
@@ -363,7 +363,7 @@ def batch_top_tokens(
     # Find the new "fuzzy" top n values
     top_n_indices = (logprobs >= nth_highest).nonzero()
     _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
 
     k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
     # Take a new topk for these new max n values
     top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
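An aside on the `batch_top_tokens` hunk (not part of the commit): per batch row, every logprob that ties with the requested n-th highest value is kept (a "fuzzy" top-n), and the later `topk` call uses the largest such count across the batch. A toy sketch of that counting step:

```python
import torch

logprobs = torch.tensor([[-0.1, -0.1, -2.0, -3.0],
                         [-0.5, -1.0, -1.0, -1.0]])
# Suppose the n-th highest value per row (ties included) is:
nth_highest = torch.tensor([[-0.1], [-1.0]])

top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
print(top_n_ishes)  # tensor([2, 4]): row 0 keeps 2 candidates, row 1 keeps 4

k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)
print(top_k.values.shape)  # torch.Size([2, 4])
```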
@@ -62,7 +62,7 @@ class Weights:
     def get_shape(self, tensor_name: str):
         return self._get_slice(tensor_name).get_shape()
 
-    def get_tensor(self, tensor_name: str, to_device = True):
+    def get_tensor(self, tensor_name: str, to_device=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         tensor = f.get_tensor(tensor_name)
@@ -110,7 +110,6 @@ class Weights:
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
         return self.get_partial_sharded(tensor_name, dim)
 
-
     def _get_qweight(self, name: str):
         slice_ = self._get_slice(name)
         total_size = slice_.get_shape()[1]
@@ -119,14 +118,16 @@ class Weights:
         world_size = self.process_group.size()
         rank = self.process_group.rank()
 
-        assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
+        assert (
+            single_size % world_size == 0
+        ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
         block_size = single_size // world_size
         start = rank * block_size
         stop = (rank + 1) * block_size
         q = slice_[:, start:stop]
-        k = slice_[:, start+single_size:stop+single_size]
-        v = slice_[:, start+2*single_size:stop+2*single_size]
-        weight = torch.cat([q,k,v], dim=1)
+        k = slice_[:, start + single_size : stop + single_size]
+        v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
+        weight = torch.cat([q, k, v], dim=1)
         weight = weight.to(device=self.device)
         return weight
 
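An aside on the `_get_qweight` hunk (not part of the commit): the prepacked tensor stores the Q, K and V blocks side by side along the packed dimension, so each rank slices its `block_size` columns at offsets 0, `single_size` and `2 * single_size`, then concatenates the three slices. The same index arithmetic on a toy tensor:

```python
import torch

single_size = 8                  # columns per Q / K / V block (toy value)
packed = torch.arange(3 * single_size).reshape(1, -1)  # pretend [Q | K | V]

world_size, rank = 2, 1          # shard 1 of 2
assert single_size % world_size == 0
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size

q = packed[:, start:stop]
k = packed[:, start + single_size : stop + single_size]
v = packed[:, start + 2 * single_size : stop + 2 * single_size]
print(torch.cat([q, k, v], dim=1))
# tensor([[ 4,  5,  6,  7, 12, 13, 14, 15, 20, 21, 22, 23]])
```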
@@ -137,14 +138,14 @@ class Weights:
         """
         if quantize in ["gptq", "awq"]:
             try:
                 qweight = self._get_qweight(f"{prefix}.qweight")
             except RuntimeError:
                 raise RuntimeError(
                     f"Cannot load `{quantize}` weight, make sure the model is already quantized."
                 )
 
             qzeros = self._get_qweight(f"{prefix}.qzeros")
             scales = self._get_qweight(f"{prefix}.scales")
             scales = scales.to(dtype=self.dtype)
             if quantize == "gptq":
                 g_idx = self.get_tensor(f"{prefix}.g_idx")
@@ -154,21 +155,23 @@ class Weights:
             bits, groupsize = self._get_gptq_params()
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
         else:
             slice_ = self._get_slice(f"{prefix}.weight")
             total_size = slice_.get_shape()[0]
             assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
             single_size = total_size // 3
             world_size = self.process_group.size()
             rank = self.process_group.rank()
 
-            assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
+            assert (
+                single_size % world_size == 0
+            ), f"Prepacked qkv cannot be sharded across {world_size} shards"
             block_size = single_size // world_size
             start = rank * block_size
             stop = (rank + 1) * block_size
             q = slice_[start:stop]
-            k = slice_[start+single_size:stop+single_size]
-            v = slice_[start+2*single_size:stop+2*single_size]
-            weight = torch.cat([q,k,v], dim=0)
+            k = slice_[start + single_size : stop + single_size]
+            v = slice_[start + 2 * single_size : stop + 2 * single_size]
+            weight = torch.cat([q, k, v], dim=0)
             weight = weight.to(device=self.device)
             weight = weight.to(dtype=self.dtype)
             return weight
@@ -205,7 +208,7 @@ class Weights:
             w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
             weight = torch.cat(w, dim=dim)
         return weight
 
     def get_tensor_shard(self, var, dim):
         world_size = self.process_group.size()
         rank = self.process_group.rank()
@@ -220,7 +223,7 @@ class Weights:
             raise NotImplementedError("Let's make that generic when needed")
         tensor = tensor.to(dtype=self.dtype)
         tensor = tensor.to(device=self.device)
         return tensor
 
     def get_multi_weights_row(self, prefix: str, quantize: str):
         if quantize == "gptq":
@@ -303,7 +306,7 @@ class Weights:
             scales = self.get_sharded(f"{prefix}.scales", dim=0)
             g_idx = None
             use_exllama = False
 
             weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
         else:
             weight = self.get_sharded(f"{prefix}.weight", dim=1)