feat: format code (#1070)

This commit is contained in:
OlivierDehaene 2023-09-27 12:22:09 +02:00 committed by GitHub
parent b32e9ce9d5
commit 47954b81e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
28 changed files with 772 additions and 298 deletions

View File

@ -137,7 +137,7 @@ class Client:
typical_p=typical_p,
watermark=watermark,
decoder_input_details=decoder_input_details,
top_n_tokens=top_n_tokens
top_n_tokens=top_n_tokens,
)
request = Request(inputs=prompt, stream=False, parameters=parameters)

View File

@ -133,7 +133,9 @@ class Request(BaseModel):
and parameters.best_of > 1
and field_value
):
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
return field_value

View File

@ -3,7 +3,11 @@ import pytest
@pytest.fixture(scope="module")
def flash_llama_awq_handle(launcher):
with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=1, quantize="awq") as handle:
with launcher(
"abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
num_shard=1,
quantize="awq",
) as handle:
yield handle
@ -12,6 +16,7 @@ async def flash_llama_awq(flash_llama_awq_handle):
await flash_llama_awq_handle.health(300)
return flash_llama_awq_handle.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
@ -20,11 +25,13 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot):
)
assert response.details.generated_tokens == 10
assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
assert (
response.generated_text
== "\nWhat is the difference between Deep Learning and Machine"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
@ -49,16 +56,18 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot):
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load(
flash_llama_awq, generate_load, response_snapshot
):
async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot):
responses = await generate_load(
flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4
)
assert len(responses) == 4
assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
assert all(
[
r.generated_text
== "\nWhat is the difference between Deep Learning and Machine"
for r in responses
]
)
assert responses == response_snapshot

View File

@ -1,15 +1,22 @@
import pytest
@pytest.fixture(scope="module")
def flash_llama_awq_handle_sharded(launcher):
with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle:
with launcher(
"abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq",
num_shard=2,
quantize="awq",
) as handle:
yield handle
@pytest.fixture(scope="module")
async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded):
await flash_llama_awq_handle_sharded.health(300)
return flash_llama_awq_handle_sharded.client
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot):
@ -18,9 +25,13 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho
)
assert response.details.generated_tokens == 10
assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine"
assert (
response.generated_text
== "\nWhat is the difference between Deep Learning and Machine"
)
assert response == response_snapshot
@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_llama_awq_load_sharded(
@ -31,6 +42,12 @@ async def test_flash_llama_awq_load_sharded(
)
assert len(responses) == 4
assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses])
assert all(
[
r.generated_text
== "\nWhat is the difference between Deep Learning and Machine"
for r in responses
]
)
assert responses == response_snapshot

View File

@ -3,9 +3,7 @@ import pytest
@pytest.fixture(scope="module")
def idefics_handle(launcher):
with launcher(
"HuggingFaceM4/idefics-9b-instruct", num_shard=2
) as handle:
with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle:
yield handle

View File

@ -45,12 +45,15 @@ def test_stopping_criteria_max():
assert criteria(1, "") == (False, None)
assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH)
def test_batch_top_tokens():
top_n_tokens = [0, 2, 3, 4, 5]
top_n_tokens_tensor = torch.tensor(top_n_tokens)
inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5)
inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs)
topn_tok_ids, topn_tok_logprobs = batch_top_tokens(
top_n_tokens, top_n_tokens_tensor, inp_logprobs
)
assert topn_tok_ids[0] == []
assert topn_tok_ids[1] == [0, 3]

View File

@ -125,8 +125,12 @@ def download_weights(
if not is_local_model:
try:
adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json")
utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code)
adapter_config_filename = hf_hub_download(
model_id, revision=revision, filename="adapter_config.json"
)
utils.download_and_unload_peft(
model_id, revision, trust_remote_code=trust_remote_code
)
is_local_model = True
utils.weight_files(model_id, revision, extension)
return
@ -179,11 +183,12 @@ def download_weights(
import transformers
import json
if is_local_model:
config_filename = os.path.join(model_id, "config.json")
else:
config_filename = hf_hub_download(model_id, revision=revision, filename="config.json")
config_filename = hf_hub_download(
model_id, revision=revision, filename="config.json"
)
with open(config_filename, "r") as f:
config = json.load(f)
architecture = config["architectures"][0]

View File

@ -153,7 +153,11 @@ def get_model(
)
elif model_type == "mpt":
return MPTSharded(
model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == "gpt_neox":
@ -252,13 +256,13 @@ def get_model(
)
elif model_type == "idefics":
if FLASH_ATTENTION:
return IDEFICSSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
return IDEFICSSharded(
model_id,
revision,
quantize=quantize,
dtype=dtype,
trust_remote_code=trust_remote_code,
)
else:
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics"))
@ -269,13 +273,9 @@ def get_model(
"gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
if quantize == "awq":
raise ValueError(
"awq quantization is not supported for AutoModel"
)
raise ValueError("awq quantization is not supported for AutoModel")
elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"):
raise ValueError(
"4bit quantization is not supported for AutoModel"
)
raise ValueError("4bit quantization is not supported for AutoModel")
if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
return CausalLM(
model_id,

View File

@ -643,9 +643,12 @@ class CausalLM(Model):
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):

View File

@ -40,7 +40,10 @@ from text_generation_server.utils.layers import (
)
CUSTOM_KERNELS_ENABLED = False
if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
if (
torch.cuda.is_available()
and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
try:
from custom_kernels import fused_bloom_attention_cuda

View File

@ -169,6 +169,7 @@ def load_attention(config, prefix, weights):
bias=False,
)
def _load_gqa(config, prefix: str, weights):
assert config.hidden_size % config.num_attention_heads == 0
assert config.num_attention_heads % weights.process_group.size() == 0
@ -211,7 +212,10 @@ class FlashLlamaAttention(torch.nn.Module):
# config=config, prefix=f"{prefix}.rotary_emb", weights=weights
# )
self.rotary_emb = PositionRotaryEmbedding.static(
config=config, dim=self.head_size, base=config.rope_theta, device=weights.device
config=config,
dim=self.head_size,
base=config.rope_theta,
device=weights.device,
)
self.softmax_scale = self.head_size**-0.5

View File

@ -20,7 +20,12 @@ import numpy as np
from PIL import Image
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
from transformers.image_transforms import resize, to_channel_dimension_format, rescale, normalize
from transformers.image_transforms import (
resize,
to_channel_dimension_format,
rescale,
normalize,
)
from transformers.image_utils import (
ChannelDimension,
ImageInput,
@ -121,7 +126,11 @@ class IdeficsImageProcessor(BaseImageProcessor):
a PyTorch tensor of the processed images
"""
image_size = image_size if image_size is not None else self.image_size
image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels
image_num_channels = (
image_num_channels
if image_num_channels is not None
else self.image_num_channels
)
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
size = (image_size, image_size)
@ -160,9 +169,13 @@ class IdeficsImageProcessor(BaseImageProcessor):
images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images]
images = [self.rescale(image=image, scale=1 / 255) for image in images]
images = [self.normalize(x, mean=image_mean, std=image_std) for x in images]
images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images]
images = [
to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images
]
# TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available
images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"]
images = BatchFeature(
data={"pixel_values": images}, tensor_type=TensorType.PYTORCH
)["pixel_values"]
return images
@ -185,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
response.raise_for_status()
return Image.open(BytesIO(response.content))
else:
raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}")
raise ValueError(
f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}"
)
def rescale(
self,
@ -255,10 +270,9 @@ class IdeficsImageProcessor(BaseImageProcessor):
`np.ndarray`: The normalized image.
"""
# TODO 4.32
return normalize(
image, mean=mean, std=std, data_format=data_format, **kwargs
)
return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs)
import transformers
transformers.IdeficsImageProcessor = IdeficsImageProcessor

View File

@ -28,7 +28,11 @@ from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
dataclass,
)
from transformers.modeling_utils import PretrainedConfig
from transformers.utils import (
add_start_docstrings,
@ -37,8 +41,12 @@ from transformers.utils import (
replace_return_docstrings,
)
from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig
from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer
from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler
from text_generation_server.models.custom_modeling.idefics_vision import (
IdeficsVisionTransformer,
)
from text_generation_server.models.custom_modeling.idefics_perceiver import (
IdeficsPerceiverResampler,
)
from text_generation_server.utils.layers import (
TensorParallelColumnLinear,
TensorParallelEmbedding,
@ -49,10 +57,12 @@ from text_generation_server.utils.layers import (
)
import dropout_layer_norm
@dataclass
class BaseModelOutputWithPastImage(BaseModelOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None
@dataclass
class CausalLMOutputWithPastImage(CausalLMOutputWithPast):
image_hidden_states: Optional[torch.FloatTensor] = None
@ -78,25 +88,39 @@ def expand_inputs_for_generation(
**model_kwargs,
):
expanded_return_idx = (
torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device)
torch.arange(input_ids.shape[0])
.view(-1, 1)
.repeat(1, expand_size)
.view(-1)
.to(input_ids.device)
)
input_ids = input_ids.index_select(0, expanded_return_idx)
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx)
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx)
model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select(
model_kwargs["token_type_ids"] = token_type_ids.index_select(
0, expanded_return_idx
)
if attention_mask is not None:
model_kwargs["attention_mask"] = attention_mask.index_select(
0, expanded_return_idx
)
model_kwargs["image_attention_mask"] = model_kwargs[
"image_attention_mask"
].index_select(0, expanded_return_idx)
model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(
0, expanded_return_idx
)
model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx)
if is_encoder_decoder:
if encoder_outputs is None:
raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
raise ValueError(
"If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined."
)
encoder_outputs[
"last_hidden_state"
] = encoder_outputs.last_hidden_state.index_select(
0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device)
)
model_kwargs["encoder_outputs"] = encoder_outputs
@ -120,14 +144,17 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder
# update token_type_ids with last value
if "token_type_ids" in model_kwargs:
token_type_ids = model_kwargs["token_type_ids"]
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
model_kwargs["token_type_ids"] = torch.cat(
[token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
)
# update attention masks
if not is_encoder_decoder:
if "attention_mask" in model_kwargs:
attention_mask = model_kwargs["attention_mask"]
model_kwargs["attention_mask"] = torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))],
dim=-1,
)
if "image_attention_mask" in model_kwargs:
image_attention_mask = model_kwargs["image_attention_mask"]
@ -180,8 +207,12 @@ def freeze_model(model, module_exceptions=[]):
}
module_exceptions_mapped = [mapping[m] for m in module_exceptions]
for module in model.modules():
if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]):
module.requires_grad_(True) # Explicitely setting it to true to avoid any mistakes
if module_exceptions and any(
[isinstance(module, t) for t in module_exceptions_mapped]
):
module.requires_grad_(
True
) # Explicitely setting it to true to avoid any mistakes
else:
module.requires_grad_(False)
return model
@ -195,15 +226,21 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module):
):
super().__init__()
self.num_embeddings = config.vocab_size
self.weight = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights)
self.additional_weight = nn.Parameter(weights.get_tensor(f"model.embed_tokens.additional_embedding.weight"))
self.weight = TensorParallelEmbedding(
prefix="model.embed_tokens", weights=weights
)
self.additional_weight = nn.Parameter(
weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")
)
def forward(self, input_ids):
# Clone so that we don't modify the original input_ids later on
input_ids = input_ids.clone()
additional_vocab_indices = torch.where(input_ids >= self.num_embeddings)
input_ids_additional_vocab = input_ids[additional_vocab_indices]
additional_embeddings = torch.nn.functional.embedding(input_ids_additional_vocab - self.num_embeddings, self.additional_weight)
additional_embeddings = torch.nn.functional.embedding(
input_ids_additional_vocab - self.num_embeddings, self.additional_weight
)
# for successful lookup replace input_ids with 0, the results of these will be discarded anyway
input_ids[additional_vocab_indices] = 0
@ -234,7 +271,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
config=config, prefix="lm_head", weights=weights
)
self.additional_fc = FastLinear.load(
config=config, prefix="lm_head.additional_fc", weights=weights, bias=False,
config=config,
prefix="lm_head.additional_fc",
weights=weights,
bias=False,
)
def forward(self, input: torch.Tensor) -> torch.Tensor:
@ -257,7 +297,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module):
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
input_ids_shape: torch.Size,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
):
"""
Make causal mask used for bi-directional self-attention.
@ -269,8 +312,18 @@ def _make_causal_mask(
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
@ -284,7 +337,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
class IdeficsRMSNorm(nn.Module):
@ -346,7 +401,6 @@ class IdeficsRMSNorm(nn.Module):
if unwrap:
normed_hidden_states = normed_hidden_states.view(*shape)
return normed_hidden_states
@ -367,7 +421,10 @@ class IdeficsMLP(nn.Module):
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
config, prefix=f"{prefix}.down_proj", weights=weights, bias=False,
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.act_fn = ACT2FN[config.hidden_act]
@ -375,7 +432,9 @@ class IdeficsMLP(nn.Module):
gate_up_states = self.gate_up_proj(hidden_states)
shape = gate_up_states.shape
gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2)
return self.down_proj(self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1])
return self.down_proj(
self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]
)
# this was adapted from LlamaAttention
@ -445,14 +504,22 @@ class IdeficsAttention(nn.Module):
self.qk_layer_norms = qk_layer_norms
if self.qk_layer_norms:
self.q_layer_norm = IdeficsRMSNorm(
prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps
)
prefix=f"{prefix}.q_layer_norm",
weights=weights,
eps=config.rms_norm_eps,
)
self.k_layer_norm = IdeficsRMSNorm(
prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps
)
prefix=f"{prefix}.q_layer_norm",
weights=weights,
eps=config.rms_norm_eps,
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
@ -470,20 +537,42 @@ class IdeficsAttention(nn.Module):
bsz, q_len, _ = hidden_states.size()
if is_cross_attention:
query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
query_states = self.q_proj(hidden_states).view(
bsz, q_len, self.num_heads, self.head_dim
) # .transpose(1, 2)
query_states = query_states.transpose(1, 2)
_, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len`
key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
(
_,
kv_len,
_,
) = (
key_value_states.size()
) # Note that, in this case, `kv_len` == `kv_seq_len`
key_states = (
self.k_proj(key_value_states)
.view(bsz, kv_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
value_states = (
self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
self.v_proj(key_value_states)
.view(bsz, kv_len, self.num_heads, self.head_dim)
.transpose(1, 2)
)
else:
qkv = self.qkv(hidden_states)
query_states, key_states, value_states = qkv.split(self.num_heads * self.head_dim, dim=2)
query_states, key_states, value_states = qkv.split(
self.num_heads * self.head_dim, dim=2
)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
) # .transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_heads, self.head_dim
) # . transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_heads, self.head_dim
) # .transpose(1, 2)
kv_seq_len = q_len
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
@ -493,10 +582,14 @@ class IdeficsAttention(nn.Module):
)
shape = query_states.shape
query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape)
query_states = self.rotary_emb(
query_states.view(-1, *shape[2:]), cos, sin
).view(shape)
shape = key_states.shape
key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape)
key_states = self.rotary_emb(
key_states.reshape(-1, *shape[2:]), cos, sin
).view(shape)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
@ -571,8 +664,14 @@ class IdeficsDecoderLayer(nn.Module):
prefix=f"{prefix}.mlp",
weights=weights,
)
self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps)
self.input_layernorm = IdeficsRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = IdeficsRMSNorm(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.dropout = config.dropout
def forward(
@ -583,7 +682,9 @@ class IdeficsDecoderLayer(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -650,14 +751,22 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
prefix=f"{prefix}.mlp",
weights=weights,
)
self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps)
self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps)
self.input_layernorm = IdeficsRMSNorm(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
)
self.post_attention_layernorm = IdeficsRMSNorm(
prefix=f"{prefix}.post_attention_layernorm",
weights=weights,
eps=config.rms_norm_eps,
)
self.config = config.dropout
self.act_cross_attn = nn.Tanh()
self.act_dense = nn.Tanh()
self.alpha_cross_attn = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_cross_attn"))
self.alpha_cross_attn = nn.Parameter(
weights.get_tensor(f"{prefix}.alpha_cross_attn")
)
self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense"))
if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")):
@ -673,7 +782,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
use_cache: Optional[bool] = False,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
no_images: Optional[bool] = False,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@ -695,7 +806,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
)
if past_key_value is not None:
raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.")
raise NotImplementedError(
"Past key value states are not implemented for Idefics cross attention module."
)
residual = hidden_states
@ -711,7 +824,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module):
# hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training)
# when there are no images the model is used in pure language mode
gate = 0 if no_images else 1
hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
hidden_states = (
residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states
)
# Fully Connected
residual = hidden_states
@ -896,11 +1011,14 @@ class IdeficsModel(IdeficsPreTrainedModel):
self.gated_cross_attn_layers = nn.ModuleList(
[
IdeficsGatedCrossAttentionLayer(layer_id, config, weights)
for layer_id in range(num_cross_layers)]
for layer_id in range(num_cross_layers)
]
)
# self.gradient_checkpointing = False
self.norm = IdeficsRMSNorm(prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps)
self.norm = IdeficsRMSNorm(
prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps
)
# self.gradient_checkpointing = False
# Initialize weights and apply final processing
@ -932,7 +1050,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
# self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
def _prepare_decoder_attention_mask(
self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
@ -946,11 +1066,13 @@ class IdeficsModel(IdeficsPreTrainedModel):
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
expanded_attn_mask = _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
).to(inputs_embeds.device)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
expanded_attn_mask
if combined_attention_mask is None
else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
@ -974,23 +1096,35 @@ class IdeficsModel(IdeficsPreTrainedModel):
) -> Union[Tuple, BaseModelOutputWithPastImage]:
device = input_ids.device if input_ids is not None else inputs_embeds.device
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
raise ValueError(
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
)
elif input_ids is not None:
batch_size, seq_length = input_ids.shape
elif inputs_embeds is not None:
batch_size, seq_length, _ = inputs_embeds.shape
else:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
raise ValueError(
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
)
seq_length_with_past = seq_length
past_key_values_length = 0
@ -1006,7 +1140,10 @@ class IdeficsModel(IdeficsPreTrainedModel):
elif position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
@ -1016,29 +1153,52 @@ class IdeficsModel(IdeficsPreTrainedModel):
if image_hidden_states is None:
if pixel_values is None and image_embeddings is None:
raise ValueError("Either pixel_values and image_embeddings have to be not-None.")
raise ValueError(
"Either pixel_values and image_embeddings have to be not-None."
)
elif pixel_values is not None and image_embeddings is not None:
raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time")
raise ValueError(
"You cannot specify both pixel_values and image_embeddings at the same time"
)
elif pixel_values is not None:
no_images = len(torch.nonzero(pixel_values)) == 0
pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility
pixel_values = pixel_values.to(
dtype=self.dtype, device=device
) # fp16 compatibility
batch_size, num_images = pixel_values.shape[:2]
pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:])
pixel_values = pixel_values.contiguous().view(
batch_size * num_images, *pixel_values.shape[2:]
)
# Get sequence from the vision encoder
image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state
image_hidden_states = self.vision_model(
pixel_values=pixel_values
).last_hidden_state
elif image_embeddings is not None:
batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size()
image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device)
image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size)
(
batch_size,
num_images,
image_seq_len,
image_hidden_size,
) = image_embeddings.size()
image_hidden_states = image_embeddings.to(
dtype=self.dtype, device=input_ids.device
)
image_hidden_states = image_hidden_states.view(
batch_size * num_images, image_seq_len, image_hidden_size
)
if self.config.use_resampler:
image_hidden_states = self.perceiver_resampler(image_hidden_states)
image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2)
image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size)
image_seq_len, image_hidden_size = image_hidden_states.size(
1
), image_hidden_states.size(2)
image_hidden_states = image_hidden_states.view(
batch_size, num_images * image_seq_len, image_hidden_size
)
else:
no_images = False
num_images = pixel_values.shape[1]
@ -1050,7 +1210,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
text_seq_len = image_attention_mask.size(1)
image_attention_mask = image_attention_mask.unsqueeze(-1)
image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len)
image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len)
image_attention_mask = image_attention_mask.view(
batch_size, text_seq_len, num_images * image_seq_len
)
image_batch_size, image_sequence_length, _ = image_hidden_states.size()
image_hidden_shape = (image_batch_size, image_sequence_length)
if image_attention_mask is None:
@ -1060,7 +1222,6 @@ class IdeficsModel(IdeficsPreTrainedModel):
# if list(image_attention_mask.shape) != [4, 1, 1024, 64]:
# raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}")
# if image_hidden_states is not None:
# else:
# image_attention_mask = None
@ -1070,10 +1231,15 @@ class IdeficsModel(IdeficsPreTrainedModel):
# embed positions
if attention_mask is None:
attention_mask = torch.ones(
(batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
(batch_size, seq_length_with_past),
dtype=torch.bool,
device=inputs_embeds.device,
)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)
hidden_states = inputs_embeds
@ -1094,7 +1260,9 @@ class IdeficsModel(IdeficsPreTrainedModel):
if output_hidden_states:
all_hidden_states += (hidden_states,)
past_key_value = past_key_values[idx] if past_key_values is not None else None
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
def vblock(
main_block,
@ -1194,7 +1362,11 @@ class IdeficsModel(IdeficsPreTrainedModel):
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return tuple(
v
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
if v is not None
)
return BaseModelOutputWithPastImage(
last_hidden_state=hidden_states,
past_key_values=next_cache,
@ -1230,7 +1402,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
inputs_embeds: Optional[torch.FloatTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
image_embeddings: Optional[torch.FloatTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
image_hidden_states: Optional[torch.FloatTensor] = None,
image_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
@ -1264,11 +1436,19 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
"Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
@ -1298,7 +1478,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
image_hidden_states=outputs.image_hidden_states
image_hidden_states=outputs.image_hidden_states,
)
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
@ -1316,12 +1496,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel):
return expand_inputs_for_generation(*args, **model_kwargs)
@staticmethod
def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False):
return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder)
def _update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False
):
return update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder
)
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
reordered_past += (
tuple(
past_state.index_select(0, beam_idx) for past_state in layer_past
),
)
return reordered_past

View File

@ -46,7 +46,8 @@ from text_generation_server.utils.layers import (
TensorParallelRowLinear,
)
EPS=1e-5
EPS = 1e-5
class IdeficsPerceiverResampler(nn.Module):
def __init__(
@ -78,7 +79,12 @@ class IdeficsPerceiverResampler(nn.Module):
"""
super().__init__()
self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents
self.embed_dim, self.n_heads, self.head_dim, self.n_latents = (
embed_dim,
n_heads,
head_dim,
n_latents,
)
self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver
# Create Latents for Perceiver
@ -107,14 +113,16 @@ class IdeficsPerceiverResampler(nn.Module):
prefix=f"{prefix}.blocks.{layer_id}.1",
intermediate_size=self.intermediate_dim,
config=config,
weights=weights
weights=weights,
),
]
)
for layer_id in range(depth)
]
)
self.layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS)
self.layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS
)
def forward(self, context: torch.Tensor) -> torch.Tensor:
"""Resample arbitrary length context & *compress* down to self.n_latents latent embeddings"""
@ -130,25 +138,34 @@ class IdeficsPerceiverResampler(nn.Module):
class IdeficsPerceiverAttention(nn.Module):
def __init__(self,
prefix,
config,
embed_dim: int,
n_heads: int,
head_dim: int,
qk_layer_norms: bool,
weights
) -> None:
def __init__(
self,
prefix,
config,
embed_dim: int,
n_heads: int,
head_dim: int,
qk_layer_norms: bool,
weights,
) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__()
self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim
self.qk_layer_norms = qk_layer_norms
# Normalization & Scaling
self.context_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS)
self.latents_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS)
self.context_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS
)
self.latents_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS
)
if self.qk_layer_norms:
self.q_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS)
self.k_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS)
self.q_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS
)
self.k_layer_norm = nn.LayerNorm.load(
prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS
)
self.qk_scale = self.head_dim**-0.5
@ -164,10 +181,10 @@ class IdeficsPerceiverAttention(nn.Module):
self.q_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False
)
self.k_proj = TensorParallelColumnLinear.load(
self.k_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False
)
self.v_proj = TensorParallelColumnLinear.load(
self.v_proj = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False
)
@ -202,7 +219,12 @@ class IdeficsPerceiverAttention(nn.Module):
# Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call)
# =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)]
# einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads)
q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)]
q, k, v = [
x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(
1, 2
)
for x in (q, k, v)
]
if self.qk_layer_norms:
q = self.q_layer_norm(q)
@ -219,25 +241,34 @@ class IdeficsPerceiverAttention(nn.Module):
class IdeficsMLP(nn.Module):
def __init__(self,
prefix,
intermediate_size,
config,
weights,
):
def __init__(
self,
prefix,
intermediate_size,
config,
weights,
):
"""Simple MLP block with intermediate_size and embedding size"""
super().__init__()
self.embed_dim = config.vision_config.embed_dim
self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS)
self.fc = TensorParallelColumnLinear.load(
config=config, prefix=f"{prefix}.fc", weights=weights, bias=False,
config=config,
prefix=f"{prefix}.fc",
weights=weights,
bias=False,
)
self.act = nn.ReLU()
self.c_proj = TensorParallelRowLinear.load(
config=config, prefix=f"{prefix}.c_proj", weights=weights, bias=False,
config=config,
prefix=f"{prefix}.c_proj",
weights=weights,
bias=False,
)
def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
def forward(
self, hidden_states: Optional[Tuple[torch.FloatTensor]]
) -> torch.FloatTensor:
hidden_states = self.ln(hidden_states)
hidden_states = self.fc(hidden_states)
hidden_states = self.act(hidden_states)

View File

@ -21,9 +21,16 @@ from urllib.parse import urlparse
from transformers.feature_extraction_utils import BatchFeature
from transformers.processing_utils import ProcessorMixin
from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy
from transformers.tokenization_utils_base import (
BatchEncoding,
PaddingStrategy,
TextInput,
TruncationStrategy,
)
from transformers.utils import TensorType, is_torch_available
from text_generation_server.models.custom_modeling.idefics_image_processing import IdeficsImageProcessor
from text_generation_server.models.custom_modeling.idefics_image_processing import (
IdeficsImageProcessor,
)
if is_torch_available():
@ -124,7 +131,14 @@ class IdeficsProcessor(ProcessorMixin):
image_processor_class = "IdeficsImageProcessor"
tokenizer_class = "LlamaTokenizerFast"
def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs):
def __init__(
self,
image_processor,
tokenizer=None,
image_size=224,
add_end_of_utterance_token=None,
**kwargs,
):
if image_processor is None:
raise ValueError("You need to specify an `image_processor`.")
if tokenizer is None:
@ -142,7 +156,8 @@ class IdeficsProcessor(ProcessorMixin):
self.tokenizer_was_trained_with_end_of_utterance_token = (
True
if "<end_of_utterance>" in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
if "<end_of_utterance>"
in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
else False
)
@ -265,7 +280,9 @@ class IdeficsProcessor(ProcessorMixin):
# if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
if add_end_of_utterance_token is None:
add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
add_end_of_utterance_token = (
self.tokenizer_was_trained_with_end_of_utterance_token
)
# turn non-batched prompts into batched
if not any(isinstance(i, list) for i in prompts):
@ -358,10 +375,14 @@ class IdeficsProcessor(ProcessorMixin):
current_images = images[:local_max_num_images]
if len(current_images) > 0:
padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
padded_image_tensor = torch.zeros(
max_num_images, *current_images.size()[1:]
)
padded_image_tensor[: current_images.size(0)] = current_images
else:
padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
padded_image_tensor = torch.zeros(
max_num_images, *self.default_image_dims
)
output_images.append(padded_image_tensor)
output_input_ids.append(torch.tensor(padded_input_ids))
@ -373,14 +394,19 @@ class IdeficsProcessor(ProcessorMixin):
output_attention_masks = torch.stack(output_attention_masks)
if at_least_one_image:
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer)
image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
output_input_ids, self.tokenizer
)
image_attention_mask = incremental_to_binary_attention_mask(
image_attention_mask, num_classes=max_num_images
)
else:
# in full language mode we set the image mask to all-0s
image_attention_mask = torch.zeros(
output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
output_input_ids.shape[0],
output_input_ids.shape[1],
1,
dtype=torch.bool,
)
return BatchFeature(

View File

@ -75,7 +75,9 @@ class IdeficsVisionEmbeddings(nn.Module):
self.image_size = config.image_size
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(weights.get_tensor(f"{prefix}.class_embedding"))
self.class_embedding = nn.Parameter(
weights.get_tensor(f"{prefix}.class_embedding")
)
self.patch_embedding = nn.Conv2d.load_no_bias(
prefix=f"{prefix}.patch_embedding",
@ -91,12 +93,16 @@ class IdeficsVisionEmbeddings(nn.Module):
self.position_embedding = TensorParallelEmbedding(
prefix="model.vision_model.embeddings.position_embedding", weights=weights
)
self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
self.position_ids = (
torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device)
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = self.patch_embedding(
pixel_values.to(dtype=target_dtype)
) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
@ -132,7 +138,6 @@ class IdeficsVisionAttention(nn.Module):
self.num_heads = self.num_heads // weights.process_group.size()
self.embed_dim = self.embed_dim // weights.process_group.size()
self.k_proj = TensorParallelColumnLinear.load(
config, prefix=f"{prefix}.k_proj", weights=weights, bias=True
)
@ -147,7 +152,11 @@ class IdeficsVisionAttention(nn.Module):
)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
return (
tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
.transpose(1, 2)
.contiguous()
)
def forward(
self,
@ -186,7 +195,10 @@ class IdeficsVisionAttention(nn.Module):
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ causal_attention_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if attention_mask is not None:
@ -194,7 +206,10 @@ class IdeficsVisionAttention(nn.Module):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
attn_weights = (
attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
+ attention_mask
)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@ -204,12 +219,18 @@ class IdeficsVisionAttention(nn.Module):
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_reshaped = attn_weights.view(
bsz, self.num_heads, tgt_len, src_len
)
attn_weights = attn_weights_reshaped.view(
bsz * self.num_heads, tgt_len, src_len
)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_probs = nn.functional.dropout(
attn_weights, p=self.dropout, training=self.training
)
attn_output = torch.bmm(attn_probs, value_states)
@ -253,11 +274,15 @@ class IdeficsVisionEncoderLayer(nn.Module):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_dim = config.hidden_size
self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights)
self.self_attn = IdeficsVisionAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
)
self.layer_norm1 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps
)
self.mlp = IdeficsVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = IdeficsVisionMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights
)
self.layer_norm2 = nn.LayerNorm.load(
prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps
)
@ -318,7 +343,11 @@ class IdeficsVisionEncoder(nn.Module):
self.config = config
self.layers = nn.ModuleList(
[
IdeficsVisionEncoderLayer(prefix=f"{prefix}.encoder.layers.{layer_id}", config=config, weights=weights)
IdeficsVisionEncoderLayer(
prefix=f"{prefix}.encoder.layers.{layer_id}",
config=config,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
@ -362,11 +391,19 @@ class IdeficsVisionEncoder(nn.Module):
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@ -406,9 +443,15 @@ class IdeficsVisionEncoder(nn.Module):
encoder_states = encoder_states + (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
return tuple(
v
for v in [hidden_states, encoder_states, all_attentions]
if v is not None
)
return BaseModelOutput(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
@ -419,13 +462,19 @@ class IdeficsVisionTransformer(nn.Module):
self.config = config
embed_dim = config.hidden_size
self.embeddings = IdeficsVisionEmbeddings(prefix=f"{prefix}.embeddings", config=config, weights=weights)
self.embeddings = IdeficsVisionEmbeddings(
prefix=f"{prefix}.embeddings", config=config, weights=weights
)
self.pre_layrnorm = nn.LayerNorm.load(
prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps
)
self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights)
self.encoder = IdeficsVisionEncoder(
prefix=prefix, config=config, weights=weights
)
self.post_layernorm = nn.LayerNorm.load(
prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps
prefix=f"{prefix}.post_layernorm",
weights=weights,
eps=config.layer_norm_eps,
)
# copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
@ -440,11 +489,19 @@ class IdeficsVisionTransformer(nn.Module):
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

View File

@ -49,7 +49,10 @@ from text_generation_server.utils.layers import (
CUSTOM_KERNELS_ENABLED = False
if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True":
if (
torch.cuda.is_available()
and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True"
):
try:
from custom_kernels import fused_attention_cuda

View File

@ -1005,9 +1005,12 @@ class FlashCausalLM(Model):
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids,
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
generated_text = GeneratedText(
output_text,

View File

@ -8,7 +8,13 @@ import re
from dataclasses import dataclass
from opentelemetry import trace
from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin
from transformers import (
AutoProcessor,
AutoTokenizer,
AutoModelForCausalLM,
PreTrainedTokenizerBase,
ProcessorMixin,
)
from typing import Optional, Tuple, List, Type, Dict
from text_generation_server.models import Model
@ -23,7 +29,8 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam
import re
IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)')
IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)")
def split(string):
parts = []
@ -41,6 +48,7 @@ def split(string):
return parts
tracer = trace.get_tracer(__name__)
@ -94,7 +102,7 @@ class IdeficsCausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
processor: ProcessorMixin, # Hack
processor: ProcessorMixin, # Hack
dtype: torch.dtype,
device: torch.device,
) -> "IdeficsCausalLMBatch":
@ -137,12 +145,16 @@ class IdeficsCausalLMBatch(Batch):
padding=True,
truncation=True,
max_length=max_truncation,
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors
read_offsets.append(input_len) # To decode without potential fallbacks errors
prefix_offsets.append(
input_len - 5
) # To decode without potential fallbacks errors
read_offsets.append(
input_len
) # To decode without potential fallbacks errors
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
@ -158,14 +170,21 @@ class IdeficsCausalLMBatch(Batch):
attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"]
# Do the same for image_attention_mask
image_attention_mask = input_ids.new_zeros(
(pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1))
(
pb.size,
max_input_length + padding_right_offset,
tokenized_inputs["pixel_values"].size(1),
)
)
image_attention_mask[:, :max_input_length, :] = tokenized_inputs["image_attention_mask"]
image_attention_mask[:, :max_input_length, :] = tokenized_inputs[
"image_attention_mask"
]
position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1)
all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
all_input_ids = tokenized_inputs["input_ids"].T.split(
1, dim=1
) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list
max_tokens = len(inputs) * (max_input_length + max_decode_tokens)
@ -259,7 +278,7 @@ class IdeficsCausalLMBatch(Batch):
self.image_attention_mask.shape[1] - self.padding_right_offset
)
+ new_padding_right_offset,
:
:,
]
if self.image_hidden_states is None:
image_hidden_states = None
@ -308,7 +327,9 @@ class IdeficsCausalLMBatch(Batch):
@classmethod
@tracer.start_as_current_span("concatenate")
def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch":
def concatenate(
cls, batches: List["IdeficsCausalLMBatch"]
) -> "IdeficsCausalLMBatch":
# It adds new requests to the batch
# Used for padding
total_batch_size = 0
@ -383,12 +404,20 @@ class IdeficsCausalLMBatch(Batch):
curr_batch_max_num_images = batch.pixel_values.size(1)
if pixel_values is None:
pixel_values = batch.pixel_values.new_zeros((total_batch_size, max_num_images, 3, 224, 224))
pixel_values[start_index:end_index, :curr_batch_max_num_images] = batch.pixel_values
pixel_values = batch.pixel_values.new_zeros(
(total_batch_size, max_num_images, 3, 224, 224)
)
pixel_values[
start_index:end_index, :curr_batch_max_num_images
] = batch.pixel_values
if image_attention_mask is None:
image_attention_mask = batch.image_attention_mask.new_zeros(
(total_batch_size, max_input_length + padding_right_offset, max_num_images)
(
total_batch_size,
max_input_length + padding_right_offset,
max_num_images,
)
)
# We need to slice the attention mask to remove padding from previous steps
@ -409,11 +438,9 @@ class IdeficsCausalLMBatch(Batch):
image_attention_mask[
start_index:end_index,
left_offset:-padding_right_offset,
:curr_batch_max_num_images
:curr_batch_max_num_images,
] = batch.image_attention_mask[
:,
batch_left_offset : - batch.padding_right_offset,
:
:, batch_left_offset : -batch.padding_right_offset, :
]
# Create empty tensor
@ -550,7 +577,9 @@ class IdeficsCausalLM(Model):
dtype: Optional[torch.dtype] = None,
trust_remote_code: bool = False,
):
from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text
from text_generation_server.models.custom_modeling.idefics_modeling import (
IdeficsForVisionText2Text,
)
if torch.cuda.is_available():
device = torch.device("cuda")
@ -650,9 +679,13 @@ class IdeficsCausalLM(Model):
# this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated
# token need to attend to the encoder hidden states (i.e. the vision encoder)
# Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic
image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1)
image_attention_mask = batch.image_attention_mask[
:, -(batch.padding_right_offset + 1)
].unsqueeze(1)
else:
image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset]
image_attention_mask = batch.image_attention_mask[
:, : -batch.padding_right_offset
]
logits, past, image_hidden_states = self.forward(
input_ids=batch.input_ids,
@ -725,9 +758,12 @@ class IdeficsCausalLM(Model):
# Decode generated tokens
output_text, _, _ = self.decode_token(
all_input_ids[:, 0],
prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1,
read_offset=len(all_input_ids) - stopping_criteria.current_tokens,
skip_special_tokens=True
prefix_offset=len(all_input_ids)
- stopping_criteria.current_tokens
- 1,
read_offset=len(all_input_ids)
- stopping_criteria.current_tokens,
skip_special_tokens=True,
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
@ -761,7 +797,7 @@ class IdeficsCausalLM(Model):
else:
prefill_tokens = None
top_tokens=None
top_tokens = None
generation = Generation(
request.id,
@ -771,7 +807,7 @@ class IdeficsCausalLM(Model):
next_token_text,
next_token_id_squeezed.item() in self.all_special_ids,
generated_text,
top_tokens
top_tokens,
)
generations.append(generation)
@ -793,7 +829,9 @@ class IdeficsCausalLM(Model):
# Update attention_mask as we added a new token to input_ids
batch.attention_mask[:, -batch.padding_right_offset] = 1
batch.image_attention_mask[:, -batch.padding_right_offset, :] = batch.image_attention_mask[:, -(batch.padding_right_offset+1), :]
batch.image_attention_mask[
:, -batch.padding_right_offset, :
] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :]
# Decrease right offset
batch.padding_right_offset -= 1

View File

@ -71,7 +71,8 @@ class Model(ABC):
# The prefix text is necessary only to defeat cleanup algorithms in the decode
# which decide to add a space or not depending on the surrounding ids.
prefix_text = self.tokenizer.decode(
all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens
all_input_ids[prefix_offset:read_offset],
skip_special_tokens=skip_special_tokens,
)
new_text = self.tokenizer.decode(
all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens

View File

@ -712,9 +712,11 @@ class Seq2SeqLM(Model):
# Decode all tokens
output_text, _, _ = self.decode_token(
all_decoder_input_ids,
prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1,
prefix_offset=len(all_decoder_input_ids)
- decoder_input_length
- 1,
read_offset=len(all_decoder_input_ids) - decoder_input_length,
skip_special_tokens=True
skip_special_tokens=True,
)
# Get seed

View File

@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
@ -26,7 +27,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
# Force inference mode for the lifetime of TextGenerationService
self._inference_mode_raii_guard = torch._C._InferenceMode(True)
async def Info(self, request, context):
return self.model.info
@ -55,9 +55,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
async def Warmup(self, request, context):
if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
if (
self.model.batch_type == IdeficsCausalLMBatch
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
request.batch,
self.model.tokenizer,
self.model.processor,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(
@ -70,9 +76,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
)
async def Prefill(self, request, context):
if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call
if (
self.model.batch_type == IdeficsCausalLMBatch
): # Hack, i would rather use kwargs in the `from_pb` call
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device
request.batch,
self.model.tokenizer,
self.model.processor,
self.model.dtype,
self.model.device,
)
else:
batch = self.model.batch_type.from_pb(

View File

@ -11,7 +11,7 @@ import awq_inference_engine # with CUDA kernels
# super().__init__()
# self.act = module
# self.scales = nn.Parameter(scales.data)
#
#
# def forward(self, x):
# return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
@ -19,10 +19,10 @@ import awq_inference_engine # with CUDA kernels
class WQLinear(nn.Module):
def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = qweight.shape[0]
self.out_features = qweight.shape[1] * 32 // w_bit
@ -42,7 +42,9 @@ class WQLinear(nn.Module):
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out_shape = x.shape[:-1] + (self.out_features,)
out = awq_inference_engine.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)

View File

@ -578,7 +578,9 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code):
return trainloader, valenc
def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False):
def get_loaders(
name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False
):
if "wikitext2" in name:
return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code)
if "ptb" in name:
@ -927,7 +929,7 @@ def quantize(
seed=seed,
model_id=model_id,
seqlen=model.seqlen,
trust_remote_code=trust_remote_code
trust_remote_code=trust_remote_code,
)
tick = time.time()

View File

@ -22,7 +22,7 @@ from text_generation_server.utils.gptq.quant_linear import QuantLinear
HAS_AWQ = True
try:
try:
from text_generation_server.utils.awq.quantize.qmodule import WQLinear
except ImportError:
HAS_AWQ = False
@ -36,17 +36,19 @@ CAN_EXLLAMA = major >= 8
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA = False
elif CAN_EXLLAMA:
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
HAS_EXLLAMA = True
except ImportError:
pass
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
HAS_EXLLAMA = True
except ImportError:
pass
from typing import Optional
HAS_EETQ = False
try:
from EETQ import quant_weights, w8_a16_gemm
HAS_EETQ = True
except ImportError:
pass
@ -74,12 +76,18 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps):
ln.bias = None
return ln
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
weight = weights.get_tensor(f"{prefix}.weight")
bias = weights.get_tensor(f"{prefix}.bias")
with init_empty_weights():
conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
conv2d = cls(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
conv2d.weight = nn.Parameter(weight)
conv2d.bias = nn.Parameter(bias)
@ -87,10 +95,17 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st
@classmethod
def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
def load_conv2d_no_bias(
cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
weight = weights.get_tensor(f"{prefix}.weight")
with init_empty_weights():
conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride)
conv2d = cls(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
)
conv2d.weight = nn.Parameter(weight)
conv2d.bias = None
@ -215,7 +230,10 @@ class Linear4bit(nn.Module):
def __init__(self, weight, bias, quant_type):
super().__init__()
self.weight = Params4bit(
weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type
weight.data,
requires_grad=False,
compress_statistics=True,
quant_type=quant_type,
)
self.compute_dtype = None
self.weight.cuda(weight.device)
@ -246,7 +264,10 @@ class Linear4bit(nn.Module):
@lru_cache(1)
def warn_deprecate_bnb():
logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce")
logger.warning(
"Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce"
)
def get_linear(weight, bias, quantize):
if quantize is None:
@ -255,7 +276,9 @@ def get_linear(weight, bias, quantize):
if HAS_EETQ:
linear = EETQLinear(weight, bias)
else:
raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ")
raise ImportError(
"Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
)
elif quantize == "bitsandbytes":
warn_deprecate_bnb()
linear = Linear8bitLt(
@ -305,7 +328,14 @@ def get_linear(weight, bias, quantize):
raise NotImplementedError(
f"The passed weight is not `awq` compatible, loader needs to be updated."
)
linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None)
linear = WQLinear(
w_bit=bits,
group_size=groupsize,
qweight=qweight,
qzeros=qzeros,
scales=scales,
bias=bias is not None,
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer):
@classmethod
def load_qkv(cls, config, prefix: str, weights, bias: bool):
"""Specific method when the QKV was joined after the fact"""
weight = weights.get_weights_col_packed_qkv(
prefix, quantize=config.quantize
)
weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize)
if bias:
raise NotImplementedError("packed_qkv only implemented for baichuan")
else:
@ -530,14 +558,16 @@ try:
def _create_inv_freq(dim, base, device):
inv_freq = 1.0 / (
base
** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
)
return inv_freq
def _get_rope_config(config):
if os.getenv("ROPE_SCALING", None) is not None:
rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])}
rope_scaling = {
"type": os.environ["ROPE_SCALING"],
"factor": float(os.environ["ROPE_FACTOR"]),
}
return rope_scaling
return getattr(config, "rope_scaling", None)
@ -563,9 +593,17 @@ try:
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor)
return DynamicPositionRotaryEmbedding(
dim=dim,
max_position_embeddings=config.max_position_embeddings,
base=base,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor)
@classmethod
@ -583,9 +621,17 @@ try:
if rope_scaling["type"] == "linear":
pass
elif rope_scaling["type"] == "dynamic":
return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor)
return DynamicPositionRotaryEmbedding(
dim=2 * inv_freq.shape[0],
max_position_embeddings=config.max_position_embeddings,
base=10000.0,
device=inv_freq.device,
scaling_factor=scaling_factor,
)
else:
raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid")
raise NotImplementedError(
f"rope scaling type {rope_scaling['type']} is not implemented or invalid"
)
return cls(inv_freq, scaling_factor)
def _update_cos_sin_cache(self, dtype, device, seqlen):
@ -645,8 +691,13 @@ try:
or self._cos_cached.dtype != dtype
):
if seqlen > self.max_position_embeddings:
newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device)
newbase = self.base * (
(self.scaling_factor * seqlen / self.max_position_embeddings)
- (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
self.inv_freq = _create_inv_freq(
self.dim, newbase, self.inv_freq.device
)
self._seq_len_cached = seqlen
t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype)
# Don't do einsum, it converts fp32 to fp16
@ -656,6 +707,5 @@ try:
self._cos_cached = torch.cos(freqs).to(dtype)
self._sin_cached = torch.sin(freqs).to(dtype)
except ImportError:
pass

View File

@ -6,6 +6,7 @@ import torch
from transformers import AutoTokenizer
from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM
def download_and_unload_peft(model_id, revision, trust_remote_code):
torch_dtype = torch.float16
@ -33,7 +34,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
base_model_id = model.peft_config["default"].base_model_name_or_path
model = model.merge_and_unload()
os.makedirs(model_id, exist_ok=True)
cache_dir = model_id
logger.info(f"Saving the newly created merged model to {cache_dir}")
@ -41,6 +42,3 @@ def download_and_unload_peft(model_id, revision, trust_remote_code):
model.save_pretrained(cache_dir, safe_serialization=True)
model.config.save_pretrained(cache_dir)
tokenizer.save_pretrained(cache_dir)

View File

@ -363,7 +363,7 @@ def batch_top_tokens(
# Find the new "fuzzy" top n values
top_n_indices = (logprobs >= nth_highest).nonzero()
_, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True)
k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max()
# Take a new topk for these new max n values
top_k = torch.topk(logprobs, k=k, dim=1, sorted=True)

View File

@ -62,7 +62,7 @@ class Weights:
def get_shape(self, tensor_name: str):
return self._get_slice(tensor_name).get_shape()
def get_tensor(self, tensor_name: str, to_device = True):
def get_tensor(self, tensor_name: str, to_device=True):
filename, tensor_name = self.get_filename(tensor_name)
f = self._get_handle(filename)
tensor = f.get_tensor(tensor_name)
@ -110,7 +110,6 @@ class Weights:
), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
return self.get_partial_sharded(tensor_name, dim)
def _get_qweight(self, name: str):
slice_ = self._get_slice(name)
total_size = slice_.get_shape()[1]
@ -119,14 +118,16 @@ class Weights:
world_size = self.process_group.size()
rank = self.process_group.rank()
assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
assert (
single_size % world_size == 0
), f"Prepacked quantized qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[:, start:stop]
k = slice_[:, start+single_size:stop+single_size]
v = slice_[:, start+2*single_size:stop+2*single_size]
weight = torch.cat([q,k,v], dim=1)
k = slice_[:, start + single_size : stop + single_size]
v = slice_[:, start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=1)
weight = weight.to(device=self.device)
return weight
@ -137,14 +138,14 @@ class Weights:
"""
if quantize in ["gptq", "awq"]:
try:
qweight = self._get_qweight(f"{prefix}.qweight")
qweight = self._get_qweight(f"{prefix}.qweight")
except RuntimeError:
raise RuntimeError(
f"Cannot load `{quantize}` weight, make sure the model is already quantized."
)
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
qzeros = self._get_qweight(f"{prefix}.qzeros")
scales = self._get_qweight(f"{prefix}.scales")
scales = scales.to(dtype=self.dtype)
if quantize == "gptq":
g_idx = self.get_tensor(f"{prefix}.g_idx")
@ -154,21 +155,23 @@ class Weights:
bits, groupsize = self._get_gptq_params()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
slice_ = self._get_slice(f"{prefix}.weight")
slice_ = self._get_slice(f"{prefix}.weight")
total_size = slice_.get_shape()[0]
assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3"
single_size = total_size // 3
world_size = self.process_group.size()
rank = self.process_group.rank()
assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards"
assert (
single_size % world_size == 0
), f"Prepacked qkv cannot be sharded across {world_size} shards"
block_size = single_size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q = slice_[start:stop]
k = slice_[start+single_size:stop+single_size]
v = slice_[start+2*single_size:stop+2*single_size]
weight = torch.cat([q,k,v], dim=0)
k = slice_[start + single_size : stop + single_size]
v = slice_[start + 2 * single_size : stop + 2 * single_size]
weight = torch.cat([q, k, v], dim=0)
weight = weight.to(device=self.device)
weight = weight.to(dtype=self.dtype)
return weight
@ -205,7 +208,7 @@ class Weights:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
return weight
def get_tensor_shard(self, var, dim):
world_size = self.process_group.size()
rank = self.process_group.rank()
@ -220,7 +223,7 @@ class Weights:
raise NotImplementedError("Let's make that generic when needed")
tensor = tensor.to(dtype=self.dtype)
tensor = tensor.to(device=self.device)
return tensor
return tensor
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
@ -303,7 +306,7 @@ class Weights:
scales = self.get_sharded(f"{prefix}.scales", dim=0)
g_idx = None
use_exllama = False
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)