diff --git a/clients/python/text_generation/client.py b/clients/python/text_generation/client.py index 015613c2..ff7f66a3 100644 --- a/clients/python/text_generation/client.py +++ b/clients/python/text_generation/client.py @@ -137,7 +137,7 @@ class Client: typical_p=typical_p, watermark=watermark, decoder_input_details=decoder_input_details, - top_n_tokens=top_n_tokens + top_n_tokens=top_n_tokens, ) request = Request(inputs=prompt, stream=False, parameters=parameters) diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py index 20083b19..6d6a0536 100644 --- a/clients/python/text_generation/types.py +++ b/clients/python/text_generation/types.py @@ -133,7 +133,9 @@ class Request(BaseModel): and parameters.best_of > 1 and field_value ): - raise ValidationError("`best_of` != 1 is not supported when `stream` == True") + raise ValidationError( + "`best_of` != 1 is not supported when `stream` == True" + ) return field_value diff --git a/integration-tests/models/test_flash_awq.py b/integration-tests/models/test_flash_awq.py index f0b99a3b..62a95f48 100644 --- a/integration-tests/models/test_flash_awq.py +++ b/integration-tests/models/test_flash_awq.py @@ -3,7 +3,11 @@ import pytest @pytest.fixture(scope="module") def flash_llama_awq_handle(launcher): - with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=1, quantize="awq") as handle: + with launcher( + "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", + num_shard=1, + quantize="awq", + ) as handle: yield handle @@ -12,6 +16,7 @@ async def flash_llama_awq(flash_llama_awq_handle): await flash_llama_awq_handle.health(300) return flash_llama_awq_handle.client + @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_awq(flash_llama_awq, response_snapshot): @@ -20,11 +25,13 @@ async def test_flash_llama_awq(flash_llama_awq, response_snapshot): ) assert response.details.generated_tokens == 10 - assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine" + assert ( + response.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + ) assert response == response_snapshot - @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): @@ -49,16 +56,18 @@ async def test_flash_llama_awq_all_params(flash_llama_awq, response_snapshot): @pytest.mark.asyncio @pytest.mark.private -async def test_flash_llama_awq_load( - flash_llama_awq, generate_load, response_snapshot -): +async def test_flash_llama_awq_load(flash_llama_awq, generate_load, response_snapshot): responses = await generate_load( flash_llama_awq, "What is Deep Learning?", max_new_tokens=10, n=4 ) assert len(responses) == 4 - assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses]) + assert all( + [ + r.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + for r in responses + ] + ) assert responses == response_snapshot - - diff --git a/integration-tests/models/test_flash_awq_sharded.py b/integration-tests/models/test_flash_awq_sharded.py index 39ea464a..1c687fc9 100644 --- a/integration-tests/models/test_flash_awq_sharded.py +++ b/integration-tests/models/test_flash_awq_sharded.py @@ -1,15 +1,22 @@ import pytest + @pytest.fixture(scope="module") def flash_llama_awq_handle_sharded(launcher): - with launcher("abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", num_shard=2, quantize="awq") as handle: + with launcher( + "abhinavkulkarni/codellama-CodeLlama-7b-Python-hf-w4-g128-awq", + num_shard=2, + quantize="awq", + ) as handle: yield handle + @pytest.fixture(scope="module") async def flash_llama_awq_sharded(flash_llama_awq_handle_sharded): await flash_llama_awq_handle_sharded.health(300) return flash_llama_awq_handle_sharded.client + @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapshot): @@ -18,9 +25,13 @@ async def test_flash_llama_awq_sharded(flash_llama_awq_sharded, response_snapsho ) assert response.details.generated_tokens == 10 - assert response.generated_text == "\nWhat is the difference between Deep Learning and Machine" + assert ( + response.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + ) assert response == response_snapshot + @pytest.mark.asyncio @pytest.mark.private async def test_flash_llama_awq_load_sharded( @@ -31,6 +42,12 @@ async def test_flash_llama_awq_load_sharded( ) assert len(responses) == 4 - assert all([r.generated_text == "\nWhat is the difference between Deep Learning and Machine" for r in responses]) + assert all( + [ + r.generated_text + == "\nWhat is the difference between Deep Learning and Machine" + for r in responses + ] + ) assert responses == response_snapshot diff --git a/integration-tests/models/test_idefics.py b/integration-tests/models/test_idefics.py index 5659dd5c..5f4571b5 100644 --- a/integration-tests/models/test_idefics.py +++ b/integration-tests/models/test_idefics.py @@ -3,9 +3,7 @@ import pytest @pytest.fixture(scope="module") def idefics_handle(launcher): - with launcher( - "HuggingFaceM4/idefics-9b-instruct", num_shard=2 - ) as handle: + with launcher("HuggingFaceM4/idefics-9b-instruct", num_shard=2) as handle: yield handle diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index 4187ff25..0585f1fb 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -45,12 +45,15 @@ def test_stopping_criteria_max(): assert criteria(1, "") == (False, None) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) + def test_batch_top_tokens(): top_n_tokens = [0, 2, 3, 4, 5] top_n_tokens_tensor = torch.tensor(top_n_tokens) - inp_logprobs = torch.tensor([[-1., -3., -4., -2., -3.]] * 5) + inp_logprobs = torch.tensor([[-1.0, -3.0, -4.0, -2.0, -3.0]] * 5) - topn_tok_ids, topn_tok_logprobs = batch_top_tokens(top_n_tokens, top_n_tokens_tensor, inp_logprobs) + topn_tok_ids, topn_tok_logprobs = batch_top_tokens( + top_n_tokens, top_n_tokens_tensor, inp_logprobs + ) assert topn_tok_ids[0] == [] assert topn_tok_ids[1] == [0, 3] diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index cf9596c9..301acb6b 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -125,8 +125,12 @@ def download_weights( if not is_local_model: try: - adapter_config_filename = hf_hub_download(model_id, revision=revision, filename="adapter_config.json") - utils.download_and_unload_peft(model_id, revision, trust_remote_code=trust_remote_code) + adapter_config_filename = hf_hub_download( + model_id, revision=revision, filename="adapter_config.json" + ) + utils.download_and_unload_peft( + model_id, revision, trust_remote_code=trust_remote_code + ) is_local_model = True utils.weight_files(model_id, revision, extension) return @@ -179,11 +183,12 @@ def download_weights( import transformers import json - if is_local_model: config_filename = os.path.join(model_id, "config.json") else: - config_filename = hf_hub_download(model_id, revision=revision, filename="config.json") + config_filename = hf_hub_download( + model_id, revision=revision, filename="config.json" + ) with open(config_filename, "r") as f: config = json.load(f) architecture = config["architectures"][0] diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 0d96d43b..96dd1ed1 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -153,7 +153,11 @@ def get_model( ) elif model_type == "mpt": return MPTSharded( - model_id, revision, quantize=quantize, dtype=dtype, trust_remote_code=trust_remote_code + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, ) elif model_type == "gpt_neox": @@ -252,13 +256,13 @@ def get_model( ) elif model_type == "idefics": if FLASH_ATTENTION: - return IDEFICSSharded( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) + return IDEFICSSharded( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) else: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Idefics")) @@ -269,13 +273,9 @@ def get_model( "gptq quantization is not supported for AutoModel, you can try to quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`" ) if quantize == "awq": - raise ValueError( - "awq quantization is not supported for AutoModel" - ) + raise ValueError("awq quantization is not supported for AutoModel") elif (quantize == "bitsandbytes-fp4") or (quantize == "bitsandbytes-nf4"): - raise ValueError( - "4bit quantization is not supported for AutoModel" - ) + raise ValueError("4bit quantization is not supported for AutoModel") if model_type in modeling_auto.MODEL_FOR_CAUSAL_LM_MAPPING_NAMES: return CausalLM( model_id, diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 35d74b2e..fccfb0f8 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -643,9 +643,12 @@ class CausalLM(Model): # Decode generated tokens output_text, _, _ = self.decode_token( all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) # Get seed if isinstance(next_token_chooser.choice, Sampling): diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 297d5c68..5423d75a 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -40,7 +40,10 @@ from text_generation_server.utils.layers import ( ) CUSTOM_KERNELS_ENABLED = False -if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": +if ( + torch.cuda.is_available() + and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" +): try: from custom_kernels import fused_bloom_attention_cuda diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index cb0c1e85..7c743a88 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -169,6 +169,7 @@ def load_attention(config, prefix, weights): bias=False, ) + def _load_gqa(config, prefix: str, weights): assert config.hidden_size % config.num_attention_heads == 0 assert config.num_attention_heads % weights.process_group.size() == 0 @@ -211,7 +212,10 @@ class FlashLlamaAttention(torch.nn.Module): # config=config, prefix=f"{prefix}.rotary_emb", weights=weights # ) self.rotary_emb = PositionRotaryEmbedding.static( - config=config, dim=self.head_size, base=config.rope_theta, device=weights.device + config=config, + dim=self.head_size, + base=config.rope_theta, + device=weights.device, ) self.softmax_scale = self.head_size**-0.5 diff --git a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py index aec9a3dc..6fb00999 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_image_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_image_processing.py @@ -20,7 +20,12 @@ import numpy as np from PIL import Image from transformers.image_processing_utils import BaseImageProcessor, BatchFeature -from transformers.image_transforms import resize, to_channel_dimension_format, rescale, normalize +from transformers.image_transforms import ( + resize, + to_channel_dimension_format, + rescale, + normalize, +) from transformers.image_utils import ( ChannelDimension, ImageInput, @@ -121,7 +126,11 @@ class IdeficsImageProcessor(BaseImageProcessor): a PyTorch tensor of the processed images """ image_size = image_size if image_size is not None else self.image_size - image_num_channels = image_num_channels if image_num_channels is not None else self.image_num_channels + image_num_channels = ( + image_num_channels + if image_num_channels is not None + else self.image_num_channels + ) image_mean = image_mean if image_mean is not None else self.image_mean image_std = image_std if image_std is not None else self.image_std size = (image_size, image_size) @@ -160,9 +169,13 @@ class IdeficsImageProcessor(BaseImageProcessor): images = [resize(x, size, resample=PILImageResampling.BICUBIC) for x in images] images = [self.rescale(image=image, scale=1 / 255) for image in images] images = [self.normalize(x, mean=image_mean, std=image_std) for x in images] - images = [to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images] + images = [ + to_channel_dimension_format(x, ChannelDimension.FIRST) for x in images + ] # TODO: this converts to torch tensors - switch to convert_to_tensors once it becomes available - images = BatchFeature(data={"pixel_values": images}, tensor_type=TensorType.PYTORCH)["pixel_values"] + images = BatchFeature( + data={"pixel_values": images}, tensor_type=TensorType.PYTORCH + )["pixel_values"] return images @@ -185,7 +198,9 @@ class IdeficsImageProcessor(BaseImageProcessor): response.raise_for_status() return Image.open(BytesIO(response.content)) else: - raise ValueError(f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}") + raise ValueError( + f"only a single or a list of entries is supported but got type={type(image_url_or_urls)}" + ) def rescale( self, @@ -255,10 +270,9 @@ class IdeficsImageProcessor(BaseImageProcessor): `np.ndarray`: The normalized image. """ # TODO 4.32 - return normalize( - image, mean=mean, std=std, data_format=data_format, **kwargs - ) + return normalize(image, mean=mean, std=std, data_format=data_format, **kwargs) import transformers + transformers.IdeficsImageProcessor = IdeficsImageProcessor diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 8b43ae4d..1ffe6276 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -28,7 +28,11 @@ from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel from transformers.activations import ACT2FN -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, dataclass +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + dataclass, +) from transformers.modeling_utils import PretrainedConfig from transformers.utils import ( add_start_docstrings, @@ -37,8 +41,12 @@ from transformers.utils import ( replace_return_docstrings, ) from text_generation_server.models.custom_modeling.idefics_config import IdeficsConfig -from text_generation_server.models.custom_modeling.idefics_vision import IdeficsVisionTransformer -from text_generation_server.models.custom_modeling.idefics_perceiver import IdeficsPerceiverResampler +from text_generation_server.models.custom_modeling.idefics_vision import ( + IdeficsVisionTransformer, +) +from text_generation_server.models.custom_modeling.idefics_perceiver import ( + IdeficsPerceiverResampler, +) from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -49,10 +57,12 @@ from text_generation_server.utils.layers import ( ) import dropout_layer_norm + @dataclass class BaseModelOutputWithPastImage(BaseModelOutputWithPast): image_hidden_states: Optional[torch.FloatTensor] = None + @dataclass class CausalLMOutputWithPastImage(CausalLMOutputWithPast): image_hidden_states: Optional[torch.FloatTensor] = None @@ -78,25 +88,39 @@ def expand_inputs_for_generation( **model_kwargs, ): expanded_return_idx = ( - torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, expand_size).view(-1).to(input_ids.device) + torch.arange(input_ids.shape[0]) + .view(-1, 1) + .repeat(1, expand_size) + .view(-1) + .to(input_ids.device) ) input_ids = input_ids.index_select(0, expanded_return_idx) if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = token_type_ids.index_select(0, expanded_return_idx) - - if attention_mask is not None: - model_kwargs["attention_mask"] = attention_mask.index_select(0, expanded_return_idx) - model_kwargs["image_attention_mask"] = model_kwargs["image_attention_mask"].index_select( + model_kwargs["token_type_ids"] = token_type_ids.index_select( + 0, expanded_return_idx + ) + + if attention_mask is not None: + model_kwargs["attention_mask"] = attention_mask.index_select( + 0, expanded_return_idx + ) + model_kwargs["image_attention_mask"] = model_kwargs[ + "image_attention_mask" + ].index_select(0, expanded_return_idx) + model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select( 0, expanded_return_idx ) - model_kwargs["pixel_values"] = model_kwargs["pixel_values"].index_select(0, expanded_return_idx) if is_encoder_decoder: if encoder_outputs is None: - raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.") - encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select( + raise ValueError( + "If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined." + ) + encoder_outputs[ + "last_hidden_state" + ] = encoder_outputs.last_hidden_state.index_select( 0, expanded_return_idx.to(encoder_outputs.last_hidden_state.device) ) model_kwargs["encoder_outputs"] = encoder_outputs @@ -120,14 +144,17 @@ def update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder # update token_type_ids with last value if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] - model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + model_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) # update attention masks if not is_encoder_decoder: if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1 + [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], + dim=-1, ) if "image_attention_mask" in model_kwargs: image_attention_mask = model_kwargs["image_attention_mask"] @@ -180,8 +207,12 @@ def freeze_model(model, module_exceptions=[]): } module_exceptions_mapped = [mapping[m] for m in module_exceptions] for module in model.modules(): - if module_exceptions and any([isinstance(module, t) for t in module_exceptions_mapped]): - module.requires_grad_(True) # Explicitely setting it to true to avoid any mistakes + if module_exceptions and any( + [isinstance(module, t) for t in module_exceptions_mapped] + ): + module.requires_grad_( + True + ) # Explicitely setting it to true to avoid any mistakes else: module.requires_grad_(False) return model @@ -195,15 +226,21 @@ class IdeficsDecoupledPartialTPEmbedding(nn.Module): ): super().__init__() self.num_embeddings = config.vocab_size - self.weight = TensorParallelEmbedding(prefix="model.embed_tokens", weights=weights) - self.additional_weight = nn.Parameter(weights.get_tensor(f"model.embed_tokens.additional_embedding.weight")) + self.weight = TensorParallelEmbedding( + prefix="model.embed_tokens", weights=weights + ) + self.additional_weight = nn.Parameter( + weights.get_tensor(f"model.embed_tokens.additional_embedding.weight") + ) def forward(self, input_ids): # Clone so that we don't modify the original input_ids later on input_ids = input_ids.clone() additional_vocab_indices = torch.where(input_ids >= self.num_embeddings) input_ids_additional_vocab = input_ids[additional_vocab_indices] - additional_embeddings = torch.nn.functional.embedding(input_ids_additional_vocab - self.num_embeddings, self.additional_weight) + additional_embeddings = torch.nn.functional.embedding( + input_ids_additional_vocab - self.num_embeddings, self.additional_weight + ) # for successful lookup replace input_ids with 0, the results of these will be discarded anyway input_ids[additional_vocab_indices] = 0 @@ -234,7 +271,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): config=config, prefix="lm_head", weights=weights ) self.additional_fc = FastLinear.load( - config=config, prefix="lm_head.additional_fc", weights=weights, bias=False, + config=config, + prefix="lm_head.additional_fc", + weights=weights, + bias=False, ) def forward(self, input: torch.Tensor) -> torch.Tensor: @@ -257,7 +297,10 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): # Copied from transformers.models.bart.modeling_bart._make_causal_mask def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, ): """ Make causal mask used for bi-directional self-attention. @@ -269,8 +312,18 @@ def _make_causal_mask( mask = mask.to(dtype) if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + mask = torch.cat( + [ + torch.zeros( + tgt_len, past_key_values_length, dtype=dtype, device=device + ), + mask, + ], + dim=-1, + ) + return mask[None, None, :, :].expand( + bsz, 1, tgt_len, tgt_len + past_key_values_length + ) def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): @@ -284,7 +337,9 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] inverted_mask = 1.0 - expanded_mask - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + return inverted_mask.masked_fill( + inverted_mask.to(torch.bool), torch.finfo(dtype).min + ) class IdeficsRMSNorm(nn.Module): @@ -346,7 +401,6 @@ class IdeficsRMSNorm(nn.Module): if unwrap: normed_hidden_states = normed_hidden_states.view(*shape) - return normed_hidden_states @@ -367,7 +421,10 @@ class IdeficsMLP(nn.Module): bias=False, ) self.down_proj = TensorParallelRowLinear.load( - config, prefix=f"{prefix}.down_proj", weights=weights, bias=False, + config, + prefix=f"{prefix}.down_proj", + weights=weights, + bias=False, ) self.act_fn = ACT2FN[config.hidden_act] @@ -375,7 +432,9 @@ class IdeficsMLP(nn.Module): gate_up_states = self.gate_up_proj(hidden_states) shape = gate_up_states.shape gate_up_states = gate_up_states.view(*shape[:-1], 2, shape[-1] // 2) - return self.down_proj(self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1]) + return self.down_proj( + self.act_fn(gate_up_states[:, :, 0]) * gate_up_states[:, :, 1] + ) # this was adapted from LlamaAttention @@ -445,14 +504,22 @@ class IdeficsAttention(nn.Module): self.qk_layer_norms = qk_layer_norms if self.qk_layer_norms: self.q_layer_norm = IdeficsRMSNorm( - prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps - ) + prefix=f"{prefix}.q_layer_norm", + weights=weights, + eps=config.rms_norm_eps, + ) self.k_layer_norm = IdeficsRMSNorm( - prefix=f"{prefix}.q_layer_norm", weights=weights, eps=config.rms_norm_eps - ) + prefix=f"{prefix}.q_layer_norm", + weights=weights, + eps=config.rms_norm_eps, + ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -470,20 +537,42 @@ class IdeficsAttention(nn.Module): bsz, q_len, _ = hidden_states.size() if is_cross_attention: - query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2) + query_states = self.q_proj(hidden_states).view( + bsz, q_len, self.num_heads, self.head_dim + ) # .transpose(1, 2) query_states = query_states.transpose(1, 2) - _, kv_len, _ = key_value_states.size() # Note that, in this case, `kv_len` == `kv_seq_len` - key_states = self.k_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) + ( + _, + kv_len, + _, + ) = ( + key_value_states.size() + ) # Note that, in this case, `kv_len` == `kv_seq_len` + key_states = ( + self.k_proj(key_value_states) + .view(bsz, kv_len, self.num_heads, self.head_dim) + .transpose(1, 2) + ) value_states = ( - self.v_proj(key_value_states).view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2) + self.v_proj(key_value_states) + .view(bsz, kv_len, self.num_heads, self.head_dim) + .transpose(1, 2) ) else: qkv = self.qkv(hidden_states) - query_states, key_states, value_states = qkv.split(self.num_heads * self.head_dim, dim=2) + query_states, key_states, value_states = qkv.split( + self.num_heads * self.head_dim, dim=2 + ) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim)# . transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim)# .transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ) # .transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_heads, self.head_dim + ) # . transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_heads, self.head_dim + ) # .transpose(1, 2) kv_seq_len = q_len if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] @@ -493,10 +582,14 @@ class IdeficsAttention(nn.Module): ) shape = query_states.shape - query_states = self.rotary_emb(query_states.view(-1, *shape[2:]), cos, sin).view(shape) + query_states = self.rotary_emb( + query_states.view(-1, *shape[2:]), cos, sin + ).view(shape) shape = key_states.shape - key_states = self.rotary_emb(key_states.reshape(-1, *shape[2:]), cos, sin).view(shape) + key_states = self.rotary_emb( + key_states.reshape(-1, *shape[2:]), cos, sin + ).view(shape) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) @@ -571,8 +664,14 @@ class IdeficsDecoderLayer(nn.Module): prefix=f"{prefix}.mlp", weights=weights, ) - self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) - self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps) + self.input_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) self.dropout = config.dropout def forward( @@ -583,7 +682,9 @@ class IdeficsDecoderLayer(nn.Module): past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -650,14 +751,22 @@ class IdeficsGatedCrossAttentionLayer(nn.Module): prefix=f"{prefix}.mlp", weights=weights, ) - self.input_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps) - self.post_attention_layernorm = IdeficsRMSNorm(prefix=f"{prefix}.post_attention_layernorm", weights=weights, eps=config.rms_norm_eps) + self.input_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = IdeficsRMSNorm( + prefix=f"{prefix}.post_attention_layernorm", + weights=weights, + eps=config.rms_norm_eps, + ) self.config = config.dropout self.act_cross_attn = nn.Tanh() self.act_dense = nn.Tanh() - self.alpha_cross_attn = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_cross_attn")) + self.alpha_cross_attn = nn.Parameter( + weights.get_tensor(f"{prefix}.alpha_cross_attn") + ) self.alpha_dense = nn.Parameter(weights.get_tensor(f"{prefix}.alpha_dense")) if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): @@ -673,7 +782,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module): use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, no_images: Optional[bool] = False, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -695,7 +806,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module): ) if past_key_value is not None: - raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") + raise NotImplementedError( + "Past key value states are not implemented for Idefics cross attention module." + ) residual = hidden_states @@ -711,7 +824,9 @@ class IdeficsGatedCrossAttentionLayer(nn.Module): # hidden_states = nn.functional.dropout(hidden_states, p=self.config, training=self.training) # when there are no images the model is used in pure language mode gate = 0 if no_images else 1 - hidden_states = residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states + hidden_states = ( + residual + gate * self.act_cross_attn(self.alpha_cross_attn) * hidden_states + ) # Fully Connected residual = hidden_states @@ -896,11 +1011,14 @@ class IdeficsModel(IdeficsPreTrainedModel): self.gated_cross_attn_layers = nn.ModuleList( [ IdeficsGatedCrossAttentionLayer(layer_id, config, weights) - for layer_id in range(num_cross_layers)] + for layer_id in range(num_cross_layers) + ] ) # self.gradient_checkpointing = False - self.norm = IdeficsRMSNorm(prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps) + self.norm = IdeficsRMSNorm( + prefix=f"model.norm", weights=weights, eps=config.rms_norm_eps + ) # self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -932,7 +1050,9 @@ class IdeficsModel(IdeficsPreTrainedModel): # self.embed_tokens = value # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + def _prepare_decoder_attention_mask( + self, attention_mask, input_shape, inputs_embeds, past_key_values_length + ): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] combined_attention_mask = None @@ -946,11 +1066,13 @@ class IdeficsModel(IdeficsPreTrainedModel): if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) + expanded_attn_mask = _expand_mask( + attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ).to(inputs_embeds.device) combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + expanded_attn_mask + if combined_attention_mask is None + else expanded_attn_mask + combined_attention_mask ) return combined_attention_mask @@ -974,23 +1096,35 @@ class IdeficsModel(IdeficsPreTrainedModel): ) -> Union[Tuple, BaseModelOutputWithPastImage]: device = input_ids.device if input_ids is not None else inputs_embeds.device - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError( + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) seq_length_with_past = seq_length past_key_values_length = 0 @@ -1006,7 +1140,10 @@ class IdeficsModel(IdeficsPreTrainedModel): elif position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -1016,29 +1153,52 @@ class IdeficsModel(IdeficsPreTrainedModel): if image_hidden_states is None: if pixel_values is None and image_embeddings is None: - raise ValueError("Either pixel_values and image_embeddings have to be not-None.") + raise ValueError( + "Either pixel_values and image_embeddings have to be not-None." + ) elif pixel_values is not None and image_embeddings is not None: - raise ValueError("You cannot specify both pixel_values and image_embeddings at the same time") + raise ValueError( + "You cannot specify both pixel_values and image_embeddings at the same time" + ) elif pixel_values is not None: no_images = len(torch.nonzero(pixel_values)) == 0 - pixel_values = pixel_values.to(dtype=self.dtype, device=device) # fp16 compatibility + pixel_values = pixel_values.to( + dtype=self.dtype, device=device + ) # fp16 compatibility batch_size, num_images = pixel_values.shape[:2] - pixel_values = pixel_values.contiguous().view(batch_size * num_images, *pixel_values.shape[2:]) + pixel_values = pixel_values.contiguous().view( + batch_size * num_images, *pixel_values.shape[2:] + ) # Get sequence from the vision encoder - image_hidden_states = self.vision_model(pixel_values=pixel_values).last_hidden_state + image_hidden_states = self.vision_model( + pixel_values=pixel_values + ).last_hidden_state elif image_embeddings is not None: - batch_size, num_images, image_seq_len, image_hidden_size = image_embeddings.size() - image_hidden_states = image_embeddings.to(dtype=self.dtype, device=input_ids.device) - image_hidden_states = image_hidden_states.view(batch_size * num_images, image_seq_len, image_hidden_size) + ( + batch_size, + num_images, + image_seq_len, + image_hidden_size, + ) = image_embeddings.size() + image_hidden_states = image_embeddings.to( + dtype=self.dtype, device=input_ids.device + ) + image_hidden_states = image_hidden_states.view( + batch_size * num_images, image_seq_len, image_hidden_size + ) if self.config.use_resampler: image_hidden_states = self.perceiver_resampler(image_hidden_states) - image_seq_len, image_hidden_size = image_hidden_states.size(1), image_hidden_states.size(2) - image_hidden_states = image_hidden_states.view(batch_size, num_images * image_seq_len, image_hidden_size) + image_seq_len, image_hidden_size = image_hidden_states.size( + 1 + ), image_hidden_states.size(2) + image_hidden_states = image_hidden_states.view( + batch_size, num_images * image_seq_len, image_hidden_size + ) else: no_images = False num_images = pixel_values.shape[1] @@ -1050,7 +1210,9 @@ class IdeficsModel(IdeficsPreTrainedModel): text_seq_len = image_attention_mask.size(1) image_attention_mask = image_attention_mask.unsqueeze(-1) image_attention_mask = image_attention_mask.repeat(1, 1, 1, image_seq_len) - image_attention_mask = image_attention_mask.view(batch_size, text_seq_len, num_images * image_seq_len) + image_attention_mask = image_attention_mask.view( + batch_size, text_seq_len, num_images * image_seq_len + ) image_batch_size, image_sequence_length, _ = image_hidden_states.size() image_hidden_shape = (image_batch_size, image_sequence_length) if image_attention_mask is None: @@ -1060,7 +1222,6 @@ class IdeficsModel(IdeficsPreTrainedModel): # if list(image_attention_mask.shape) != [4, 1, 1024, 64]: # raise ValueError(f"Image hidden_states {image_hidden_states.shape} - mask {image_attention_mask.shape} {num_images} {image_seq_len} {text_seq_len}") - # if image_hidden_states is not None: # else: # image_attention_mask = None @@ -1070,10 +1231,15 @@ class IdeficsModel(IdeficsPreTrainedModel): # embed positions if attention_mask is None: attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + (batch_size, seq_length_with_past), + dtype=torch.bool, + device=inputs_embeds.device, ) attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, ) hidden_states = inputs_embeds @@ -1094,7 +1260,9 @@ class IdeficsModel(IdeficsPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None + past_key_value = ( + past_key_values[idx] if past_key_values is not None else None + ) def vblock( main_block, @@ -1194,7 +1362,11 @@ class IdeficsModel(IdeficsPreTrainedModel): next_cache = next_decoder_cache if use_cache else None if not return_dict: - return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) return BaseModelOutputWithPastImage( last_hidden_state=hidden_states, past_key_values=next_cache, @@ -1230,7 +1402,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): inputs_embeds: Optional[torch.FloatTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, image_embeddings: Optional[torch.FloatTensor] = None, - image_hidden_states: Optional[torch.FloatTensor] = None, + image_hidden_states: Optional[torch.FloatTensor] = None, image_attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1264,11 +1436,19 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you." ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1298,7 +1478,7 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states + image_hidden_states=outputs.image_hidden_states, ) def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): @@ -1316,12 +1496,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): return expand_inputs_for_generation(*args, **model_kwargs) @staticmethod - def _update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=False): - return update_model_kwargs_for_generation(outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder) + def _update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=False + ): + return update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=is_encoder_decoder + ) @staticmethod def _reorder_cache(past, beam_idx): reordered_past = () for layer_past in past: - reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx) for past_state in layer_past + ), + ) return reordered_past diff --git a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py index def78390..477d4d70 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_perceiver.py +++ b/server/text_generation_server/models/custom_modeling/idefics_perceiver.py @@ -46,7 +46,8 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, ) -EPS=1e-5 +EPS = 1e-5 + class IdeficsPerceiverResampler(nn.Module): def __init__( @@ -78,7 +79,12 @@ class IdeficsPerceiverResampler(nn.Module): """ super().__init__() - self.embed_dim, self.n_heads, self.head_dim, self.n_latents = embed_dim, n_heads, head_dim, n_latents + self.embed_dim, self.n_heads, self.head_dim, self.n_latents = ( + embed_dim, + n_heads, + head_dim, + n_latents, + ) self.qk_layer_norms = config.perceiver_config.qk_layer_norms_perceiver # Create Latents for Perceiver @@ -107,14 +113,16 @@ class IdeficsPerceiverResampler(nn.Module): prefix=f"{prefix}.blocks.{layer_id}.1", intermediate_size=self.intermediate_dim, config=config, - weights=weights + weights=weights, ), ] ) for layer_id in range(depth) ] ) - self.layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS) + self.layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.layer_norm", weights=weights, eps=EPS + ) def forward(self, context: torch.Tensor) -> torch.Tensor: """Resample arbitrary length context & *compress* down to self.n_latents latent embeddings""" @@ -130,25 +138,34 @@ class IdeficsPerceiverResampler(nn.Module): class IdeficsPerceiverAttention(nn.Module): - def __init__(self, - prefix, - config, - embed_dim: int, - n_heads: int, - head_dim: int, - qk_layer_norms: bool, - weights - ) -> None: + def __init__( + self, + prefix, + config, + embed_dim: int, + n_heads: int, + head_dim: int, + qk_layer_norms: bool, + weights, + ) -> None: """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" super().__init__() self.embed_dim, self.n_heads, self.head_dim = embed_dim, n_heads, head_dim self.qk_layer_norms = qk_layer_norms # Normalization & Scaling - self.context_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS) - self.latents_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS) + self.context_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.context_layer_norm", weights=weights, eps=EPS + ) + self.latents_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.latents_layer_norm", weights=weights, eps=EPS + ) if self.qk_layer_norms: - self.q_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS) - self.k_layer_norm = nn.LayerNorm.load(prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS) + self.q_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.q_layer_norm", weights=weights, eps=EPS + ) + self.k_layer_norm = nn.LayerNorm.load( + prefix=f"{prefix}.k_layer_norm", weights=weights, eps=EPS + ) self.qk_scale = self.head_dim**-0.5 @@ -164,10 +181,10 @@ class IdeficsPerceiverAttention(nn.Module): self.q_proj = TensorParallelColumnLinear.load( config=config, prefix=f"{prefix}.q_proj", weights=weights, bias=False ) - self.k_proj = TensorParallelColumnLinear.load( + self.k_proj = TensorParallelColumnLinear.load( config=config, prefix=f"{prefix}.k_proj", weights=weights, bias=False ) - self.v_proj = TensorParallelColumnLinear.load( + self.v_proj = TensorParallelColumnLinear.load( config=config, prefix=f"{prefix}.v_proj", weights=weights, bias=False ) @@ -202,7 +219,12 @@ class IdeficsPerceiverAttention(nn.Module): # Multiheaded Self-Attention w/ stable softmax (subtract per-row max -- `amax` -- before softmax call) # =>> `attn` should be a 2D matrix of shape [n_latents x (context + n_latents)] # einsum.rearrange(x, "bsz seq (heads embed) -> bsz heads seq embed", heads=self.n_heads) - q, k, v = [x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose(1, 2) for x in (q, k, v)] + q, k, v = [ + x.reshape(batch_size, x.shape[1], self.n_heads, self.head_dim).transpose( + 1, 2 + ) + for x in (q, k, v) + ] if self.qk_layer_norms: q = self.q_layer_norm(q) @@ -219,25 +241,34 @@ class IdeficsPerceiverAttention(nn.Module): class IdeficsMLP(nn.Module): - def __init__(self, - prefix, - intermediate_size, - config, - weights, - ): + def __init__( + self, + prefix, + intermediate_size, + config, + weights, + ): """Simple MLP block with intermediate_size and embedding size""" super().__init__() self.embed_dim = config.vision_config.embed_dim self.ln = nn.LayerNorm.load(prefix=f"{prefix}.ln", weights=weights, eps=EPS) self.fc = TensorParallelColumnLinear.load( - config=config, prefix=f"{prefix}.fc", weights=weights, bias=False, + config=config, + prefix=f"{prefix}.fc", + weights=weights, + bias=False, ) self.act = nn.ReLU() self.c_proj = TensorParallelRowLinear.load( - config=config, prefix=f"{prefix}.c_proj", weights=weights, bias=False, + config=config, + prefix=f"{prefix}.c_proj", + weights=weights, + bias=False, ) - def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor: + def forward( + self, hidden_states: Optional[Tuple[torch.FloatTensor]] + ) -> torch.FloatTensor: hidden_states = self.ln(hidden_states) hidden_states = self.fc(hidden_states) hidden_states = self.act(hidden_states) diff --git a/server/text_generation_server/models/custom_modeling/idefics_processing.py b/server/text_generation_server/models/custom_modeling/idefics_processing.py index e24fc7bd..0fbcbeeb 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_processing.py +++ b/server/text_generation_server/models/custom_modeling/idefics_processing.py @@ -21,9 +21,16 @@ from urllib.parse import urlparse from transformers.feature_extraction_utils import BatchFeature from transformers.processing_utils import ProcessorMixin -from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, TextInput, TruncationStrategy +from transformers.tokenization_utils_base import ( + BatchEncoding, + PaddingStrategy, + TextInput, + TruncationStrategy, +) from transformers.utils import TensorType, is_torch_available -from text_generation_server.models.custom_modeling.idefics_image_processing import IdeficsImageProcessor +from text_generation_server.models.custom_modeling.idefics_image_processing import ( + IdeficsImageProcessor, +) if is_torch_available(): @@ -124,7 +131,14 @@ class IdeficsProcessor(ProcessorMixin): image_processor_class = "IdeficsImageProcessor" tokenizer_class = "LlamaTokenizerFast" - def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs): + def __init__( + self, + image_processor, + tokenizer=None, + image_size=224, + add_end_of_utterance_token=None, + **kwargs, + ): if image_processor is None: raise ValueError("You need to specify an `image_processor`.") if tokenizer is None: @@ -142,7 +156,8 @@ class IdeficsProcessor(ProcessorMixin): self.tokenizer_was_trained_with_end_of_utterance_token = ( True - if "" in self.tokenizer.special_tokens_map.get("additional_special_tokens", []) + if "" + in self.tokenizer.special_tokens_map.get("additional_special_tokens", []) else False ) @@ -265,7 +280,9 @@ class IdeficsProcessor(ProcessorMixin): # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it if add_end_of_utterance_token is None: - add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token + add_end_of_utterance_token = ( + self.tokenizer_was_trained_with_end_of_utterance_token + ) # turn non-batched prompts into batched if not any(isinstance(i, list) for i in prompts): @@ -358,10 +375,14 @@ class IdeficsProcessor(ProcessorMixin): current_images = images[:local_max_num_images] if len(current_images) > 0: - padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:]) + padded_image_tensor = torch.zeros( + max_num_images, *current_images.size()[1:] + ) padded_image_tensor[: current_images.size(0)] = current_images else: - padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims) + padded_image_tensor = torch.zeros( + max_num_images, *self.default_image_dims + ) output_images.append(padded_image_tensor) output_input_ids.append(torch.tensor(padded_input_ids)) @@ -373,14 +394,19 @@ class IdeficsProcessor(ProcessorMixin): output_attention_masks = torch.stack(output_attention_masks) if at_least_one_image: - image_attention_mask, _ = image_attention_mask_for_packed_input_ids(output_input_ids, self.tokenizer) + image_attention_mask, _ = image_attention_mask_for_packed_input_ids( + output_input_ids, self.tokenizer + ) image_attention_mask = incremental_to_binary_attention_mask( image_attention_mask, num_classes=max_num_images ) else: # in full language mode we set the image mask to all-0s image_attention_mask = torch.zeros( - output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool + output_input_ids.shape[0], + output_input_ids.shape[1], + 1, + dtype=torch.bool, ) return BatchFeature( diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py index 30f07095..c521dd0a 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_vision.py +++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -75,7 +75,9 @@ class IdeficsVisionEmbeddings(nn.Module): self.image_size = config.image_size self.patch_size = config.patch_size - self.class_embedding = nn.Parameter(weights.get_tensor(f"{prefix}.class_embedding")) + self.class_embedding = nn.Parameter( + weights.get_tensor(f"{prefix}.class_embedding") + ) self.patch_embedding = nn.Conv2d.load_no_bias( prefix=f"{prefix}.patch_embedding", @@ -91,12 +93,16 @@ class IdeficsVisionEmbeddings(nn.Module): self.position_embedding = TensorParallelEmbedding( prefix="model.vision_model.embeddings.position_embedding", weights=weights ) - self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device) + self.position_ids = ( + torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device) + ) def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: batch_size = pixel_values.shape[0] target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] + patch_embeds = self.patch_embedding( + pixel_values.to(dtype=target_dtype) + ) # shape = [*, width, grid, grid] patch_embeds = patch_embeds.flatten(2).transpose(1, 2) class_embeds = self.class_embedding.expand(batch_size, 1, -1) @@ -132,7 +138,6 @@ class IdeficsVisionAttention(nn.Module): self.num_heads = self.num_heads // weights.process_group.size() self.embed_dim = self.embed_dim // weights.process_group.size() - self.k_proj = TensorParallelColumnLinear.load( config, prefix=f"{prefix}.k_proj", weights=weights, bias=True ) @@ -147,7 +152,11 @@ class IdeficsVisionAttention(nn.Module): ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -186,7 +195,10 @@ class IdeficsVisionAttention(nn.Module): f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" f" {causal_attention_mask.size()}" ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + causal_attention_mask + ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) if attention_mask is not None: @@ -194,7 +206,10 @@ class IdeficsVisionAttention(nn.Module): raise ValueError( f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask + attn_weights = ( + attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + attention_mask + ) attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -204,12 +219,18 @@ class IdeficsVisionAttention(nn.Module): # make sure that attn_weights keeps its gradient. # In order to do so, attn_weights have to reshaped # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) + attn_weights_reshaped = attn_weights.view( + bsz, self.num_heads, tgt_len, src_len + ) + attn_weights = attn_weights_reshaped.view( + bsz * self.num_heads, tgt_len, src_len + ) else: attn_weights_reshaped = None - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_probs = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) attn_output = torch.bmm(attn_probs, value_states) @@ -253,11 +274,15 @@ class IdeficsVisionEncoderLayer(nn.Module): def __init__(self, prefix, config, weights): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = IdeficsVisionAttention(prefix=f"{prefix}.self_attn", config=config, weights=weights) + self.self_attn = IdeficsVisionAttention( + prefix=f"{prefix}.self_attn", config=config, weights=weights + ) self.layer_norm1 = nn.LayerNorm.load( prefix=f"{prefix}.layer_norm1", weights=weights, eps=config.layer_norm_eps ) - self.mlp = IdeficsVisionMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) + self.mlp = IdeficsVisionMLP( + prefix=f"{prefix}.mlp", config=config, weights=weights + ) self.layer_norm2 = nn.LayerNorm.load( prefix=f"{prefix}.layer_norm2", weights=weights, eps=config.layer_norm_eps ) @@ -318,7 +343,11 @@ class IdeficsVisionEncoder(nn.Module): self.config = config self.layers = nn.ModuleList( [ - IdeficsVisionEncoderLayer(prefix=f"{prefix}.encoder.layers.{layer_id}", config=config, weights=weights) + IdeficsVisionEncoderLayer( + prefix=f"{prefix}.encoder.layers.{layer_id}", + config=config, + weights=weights, + ) for layer_id in range(config.num_hidden_layers) ] ) @@ -362,11 +391,19 @@ class IdeficsVisionEncoder(nn.Module): return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -406,9 +443,15 @@ class IdeficsVisionEncoder(nn.Module): encoder_states = encoder_states + (hidden_states,) if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return tuple( + v + for v in [hidden_states, encoder_states, all_attentions] + if v is not None + ) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=encoder_states, + attentions=all_attentions, ) @@ -419,13 +462,19 @@ class IdeficsVisionTransformer(nn.Module): self.config = config embed_dim = config.hidden_size - self.embeddings = IdeficsVisionEmbeddings(prefix=f"{prefix}.embeddings", config=config, weights=weights) + self.embeddings = IdeficsVisionEmbeddings( + prefix=f"{prefix}.embeddings", config=config, weights=weights + ) self.pre_layrnorm = nn.LayerNorm.load( prefix=f"{prefix}.pre_layrnorm", weights=weights, eps=config.layer_norm_eps ) - self.encoder = IdeficsVisionEncoder(prefix=prefix, config=config, weights=weights) + self.encoder = IdeficsVisionEncoder( + prefix=prefix, config=config, weights=weights + ) self.post_layernorm = nn.LayerNorm.load( - prefix=f"{prefix}.post_layernorm", weights=weights, eps=config.layer_norm_eps + prefix=f"{prefix}.post_layernorm", + weights=weights, + eps=config.layer_norm_eps, ) # copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward @@ -440,11 +489,19 @@ class IdeficsVisionTransformer(nn.Module): Returns: """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if pixel_values is None: raise ValueError("You have to specify pixel_values") diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index c5b0c7fd..24ba6796 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -49,7 +49,10 @@ from text_generation_server.utils.layers import ( CUSTOM_KERNELS_ENABLED = False -if torch.cuda.is_available() and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True": +if ( + torch.cuda.is_available() + and not os.environ.get("DISABLE_CUSTOM_KERNELS", "False") == "True" +): try: from custom_kernels import fused_attention_cuda diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index 12d8efeb..34c7f633 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1005,9 +1005,12 @@ class FlashCausalLM(Model): # Decode generated tokens output_text, _, _ = self.decode_token( all_input_ids, - prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) generated_text = GeneratedText( output_text, diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index 30cc2299..2472caf6 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -8,7 +8,13 @@ import re from dataclasses import dataclass from opentelemetry import trace -from transformers import AutoProcessor, AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase, ProcessorMixin +from transformers import ( + AutoProcessor, + AutoTokenizer, + AutoModelForCausalLM, + PreTrainedTokenizerBase, + ProcessorMixin, +) from typing import Optional, Tuple, List, Type, Dict from text_generation_server.models import Model @@ -23,7 +29,8 @@ from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sam import re -IMAGES = re.compile(r'!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)') +IMAGES = re.compile(r"!\[[^\]]*\]\((.*?)\s*(\"(?:.*[^\"])\")?\s*\)") + def split(string): parts = [] @@ -41,6 +48,7 @@ def split(string): return parts + tracer = trace.get_tracer(__name__) @@ -94,7 +102,7 @@ class IdeficsCausalLMBatch(Batch): cls, pb: generate_pb2.Batch, tokenizer: PreTrainedTokenizerBase, - processor: ProcessorMixin, # Hack + processor: ProcessorMixin, # Hack dtype: torch.dtype, device: torch.device, ) -> "IdeficsCausalLMBatch": @@ -137,12 +145,16 @@ class IdeficsCausalLMBatch(Batch): padding=True, truncation=True, max_length=max_truncation, - add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token + add_end_of_utterance_token=False, # Already taken care of inside the prompts, so bypassing the processor's handling of this token ).to(device) for _ in pb.requests: input_len = tokenized_inputs["input_ids"].shape[1] - prefix_offsets.append(input_len - 5) # To decode without potential fallbacks errors - read_offsets.append(input_len) # To decode without potential fallbacks errors + prefix_offsets.append( + input_len - 5 + ) # To decode without potential fallbacks errors + read_offsets.append( + input_len + ) # To decode without potential fallbacks errors input_lengths = tokenized_inputs["attention_mask"].sum(1) max_input_length = input_lengths.max() @@ -158,14 +170,21 @@ class IdeficsCausalLMBatch(Batch): attention_mask[:, :max_input_length] = tokenized_inputs["attention_mask"] # Do the same for image_attention_mask image_attention_mask = input_ids.new_zeros( - (pb.size, max_input_length + padding_right_offset, tokenized_inputs["pixel_values"].size(1)) + ( + pb.size, + max_input_length + padding_right_offset, + tokenized_inputs["pixel_values"].size(1), + ) ) - image_attention_mask[:, :max_input_length, :] = tokenized_inputs["image_attention_mask"] - + image_attention_mask[:, :max_input_length, :] = tokenized_inputs[ + "image_attention_mask" + ] position_ids = tokenized_inputs["attention_mask"].long().cumsum(-1) - 1 position_ids.masked_fill_(tokenized_inputs["attention_mask"] == 0, 1) - all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list + all_input_ids = tokenized_inputs["input_ids"].T.split( + 1, dim=1 + ) # It's input_ids but splitted into a tuple of tensors where each tensor is (seq_len, 1) size. It is then transformed into a list max_tokens = len(inputs) * (max_input_length + max_decode_tokens) @@ -259,7 +278,7 @@ class IdeficsCausalLMBatch(Batch): self.image_attention_mask.shape[1] - self.padding_right_offset ) + new_padding_right_offset, - : + :, ] if self.image_hidden_states is None: image_hidden_states = None @@ -308,7 +327,9 @@ class IdeficsCausalLMBatch(Batch): @classmethod @tracer.start_as_current_span("concatenate") - def concatenate(cls, batches: List["IdeficsCausalLMBatch"]) -> "IdeficsCausalLMBatch": + def concatenate( + cls, batches: List["IdeficsCausalLMBatch"] + ) -> "IdeficsCausalLMBatch": # It adds new requests to the batch # Used for padding total_batch_size = 0 @@ -383,12 +404,20 @@ class IdeficsCausalLMBatch(Batch): curr_batch_max_num_images = batch.pixel_values.size(1) if pixel_values is None: - pixel_values = batch.pixel_values.new_zeros((total_batch_size, max_num_images, 3, 224, 224)) - pixel_values[start_index:end_index, :curr_batch_max_num_images] = batch.pixel_values + pixel_values = batch.pixel_values.new_zeros( + (total_batch_size, max_num_images, 3, 224, 224) + ) + pixel_values[ + start_index:end_index, :curr_batch_max_num_images + ] = batch.pixel_values if image_attention_mask is None: image_attention_mask = batch.image_attention_mask.new_zeros( - (total_batch_size, max_input_length + padding_right_offset, max_num_images) + ( + total_batch_size, + max_input_length + padding_right_offset, + max_num_images, + ) ) # We need to slice the attention mask to remove padding from previous steps @@ -409,11 +438,9 @@ class IdeficsCausalLMBatch(Batch): image_attention_mask[ start_index:end_index, left_offset:-padding_right_offset, - :curr_batch_max_num_images + :curr_batch_max_num_images, ] = batch.image_attention_mask[ - :, - batch_left_offset : - batch.padding_right_offset, - : + :, batch_left_offset : -batch.padding_right_offset, : ] # Create empty tensor @@ -550,7 +577,9 @@ class IdeficsCausalLM(Model): dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - from text_generation_server.models.custom_modeling.idefics_modeling import IdeficsForVisionText2Text + from text_generation_server.models.custom_modeling.idefics_modeling import ( + IdeficsForVisionText2Text, + ) if torch.cuda.is_available(): device = torch.device("cuda") @@ -650,9 +679,13 @@ class IdeficsCausalLM(Model): # this is due to the nature IDEFICS: it's an encoder decoder, and so when decoding, only the currently generated # token need to attend to the encoder hidden states (i.e. the vision encoder) # Also see seq2seq_lm.Seq2SeqLM.generate_token which has roughly the same logic - image_attention_mask = batch.image_attention_mask[:, -(batch.padding_right_offset+1)].unsqueeze(1) + image_attention_mask = batch.image_attention_mask[ + :, -(batch.padding_right_offset + 1) + ].unsqueeze(1) else: - image_attention_mask = batch.image_attention_mask[:, : -batch.padding_right_offset] + image_attention_mask = batch.image_attention_mask[ + :, : -batch.padding_right_offset + ] logits, past, image_hidden_states = self.forward( input_ids=batch.input_ids, @@ -725,9 +758,12 @@ class IdeficsCausalLM(Model): # Decode generated tokens output_text, _, _ = self.decode_token( all_input_ids[:, 0], - prefix_offset=len(all_input_ids) - stopping_criteria.current_tokens - 1, - read_offset=len(all_input_ids) - stopping_criteria.current_tokens, - skip_special_tokens=True + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) + - stopping_criteria.current_tokens, + skip_special_tokens=True, ) # Get seed if isinstance(next_token_chooser.choice, Sampling): @@ -761,7 +797,7 @@ class IdeficsCausalLM(Model): else: prefill_tokens = None - top_tokens=None + top_tokens = None generation = Generation( request.id, @@ -771,7 +807,7 @@ class IdeficsCausalLM(Model): next_token_text, next_token_id_squeezed.item() in self.all_special_ids, generated_text, - top_tokens + top_tokens, ) generations.append(generation) @@ -793,7 +829,9 @@ class IdeficsCausalLM(Model): # Update attention_mask as we added a new token to input_ids batch.attention_mask[:, -batch.padding_right_offset] = 1 - batch.image_attention_mask[:, -batch.padding_right_offset, :] = batch.image_attention_mask[:, -(batch.padding_right_offset+1), :] + batch.image_attention_mask[ + :, -batch.padding_right_offset, : + ] = batch.image_attention_mask[:, -(batch.padding_right_offset + 1), :] # Decrease right offset batch.padding_right_offset -= 1 diff --git a/server/text_generation_server/models/model.py b/server/text_generation_server/models/model.py index 73329b24..f6e66d30 100644 --- a/server/text_generation_server/models/model.py +++ b/server/text_generation_server/models/model.py @@ -71,7 +71,8 @@ class Model(ABC): # The prefix text is necessary only to defeat cleanup algorithms in the decode # which decide to add a space or not depending on the surrounding ids. prefix_text = self.tokenizer.decode( - all_input_ids[prefix_offset:read_offset], skip_special_tokens=skip_special_tokens + all_input_ids[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, ) new_text = self.tokenizer.decode( all_input_ids[prefix_offset:], skip_special_tokens=skip_special_tokens diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index f67874be..d4d3cd19 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -712,9 +712,11 @@ class Seq2SeqLM(Model): # Decode all tokens output_text, _, _ = self.decode_token( all_decoder_input_ids, - prefix_offset=len(all_decoder_input_ids) - decoder_input_length - 1, + prefix_offset=len(all_decoder_input_ids) + - decoder_input_length + - 1, read_offset=len(all_decoder_input_ids) - decoder_input_length, - skip_special_tokens=True + skip_special_tokens=True, ) # Get seed diff --git a/server/text_generation_server/server.py b/server/text_generation_server/server.py index 67137aaa..75d2b159 100644 --- a/server/text_generation_server/server.py +++ b/server/text_generation_server/server.py @@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2 from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch + class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): def __init__(self, model: Model, cache: Cache, server_urls: List[str]): self.cache = cache @@ -26,7 +27,6 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): # Force inference mode for the lifetime of TextGenerationService self._inference_mode_raii_guard = torch._C._InferenceMode(True) - async def Info(self, request, context): return self.model.info @@ -55,9 +55,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb()) async def Warmup(self, request, context): - if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call + if ( + self.model.batch_type == IdeficsCausalLMBatch + ): # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.dtype, + self.model.device, ) else: batch = self.model.batch_type.from_pb( @@ -70,9 +76,15 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer): ) async def Prefill(self, request, context): - if self.model.batch_type == IdeficsCausalLMBatch: #Hack, i would rather use kwargs in the `from_pb` call + if ( + self.model.batch_type == IdeficsCausalLMBatch + ): # Hack, i would rather use kwargs in the `from_pb` call batch = self.model.batch_type.from_pb( - request.batch, self.model.tokenizer, self.model.processor, self.model.dtype, self.model.device + request.batch, + self.model.tokenizer, + self.model.processor, + self.model.dtype, + self.model.device, ) else: batch = self.model.batch_type.from_pb( diff --git a/server/text_generation_server/utils/awq/quantize/qmodule.py b/server/text_generation_server/utils/awq/quantize/qmodule.py index c658e17f..ca8caf50 100644 --- a/server/text_generation_server/utils/awq/quantize/qmodule.py +++ b/server/text_generation_server/utils/awq/quantize/qmodule.py @@ -11,7 +11,7 @@ import awq_inference_engine # with CUDA kernels # super().__init__() # self.act = module # self.scales = nn.Parameter(scales.data) -# +# # def forward(self, x): # return self.act(x) / self.scales.view(1, 1, -1).to(x.device) @@ -19,10 +19,10 @@ import awq_inference_engine # with CUDA kernels class WQLinear(nn.Module): def __init__(self, w_bit, group_size, qweight, qzeros, scales, bias): super().__init__() - + if w_bit not in [4]: raise NotImplementedError("Only 4-bit are supported for now.") - + self.in_features = qweight.shape[0] self.out_features = qweight.shape[1] * 32 // w_bit @@ -42,7 +42,9 @@ class WQLinear(nn.Module): @torch.no_grad() def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features, ) - out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8) + out_shape = x.shape[:-1] + (self.out_features,) + out = awq_inference_engine.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8 + ) out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) diff --git a/server/text_generation_server/utils/gptq/quantize.py b/server/text_generation_server/utils/gptq/quantize.py index 9547d534..ca113d8f 100644 --- a/server/text_generation_server/utils/gptq/quantize.py +++ b/server/text_generation_server/utils/gptq/quantize.py @@ -578,7 +578,9 @@ def get_c4_new(nsamples, seed, seqlen, model_id, trust_remote_code): return trainloader, valenc -def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False): +def get_loaders( + name, nsamples=128, seed=0, seqlen=2048, model_id="", trust_remote_code=False +): if "wikitext2" in name: return get_wikitext2(nsamples, seed, seqlen, model_id, trust_remote_code) if "ptb" in name: @@ -927,7 +929,7 @@ def quantize( seed=seed, model_id=model_id, seqlen=model.seqlen, - trust_remote_code=trust_remote_code + trust_remote_code=trust_remote_code, ) tick = time.time() diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 14cb55cc..8be2463f 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -22,7 +22,7 @@ from text_generation_server.utils.gptq.quant_linear import QuantLinear HAS_AWQ = True -try: +try: from text_generation_server.utils.awq.quantize.qmodule import WQLinear except ImportError: HAS_AWQ = False @@ -36,17 +36,19 @@ CAN_EXLLAMA = major >= 8 if os.getenv("DISABLE_EXLLAMA") == "True": HAS_EXLLAMA = False elif CAN_EXLLAMA: - try: - from text_generation_server.utils.gptq.exllama import Ex4bitLinear - HAS_EXLLAMA = True - except ImportError: - pass + try: + from text_generation_server.utils.gptq.exllama import Ex4bitLinear + + HAS_EXLLAMA = True + except ImportError: + pass from typing import Optional HAS_EETQ = False try: from EETQ import quant_weights, w8_a16_gemm + HAS_EETQ = True except ImportError: pass @@ -74,12 +76,18 @@ def load_layer_norm_no_bias(cls, prefix, weights, eps): ln.bias = None return ln + @classmethod def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): weight = weights.get_tensor(f"{prefix}.weight") bias = weights.get_tensor(f"{prefix}.bias") with init_empty_weights(): - conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride) + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) conv2d.weight = nn.Parameter(weight) conv2d.bias = nn.Parameter(bias) @@ -87,10 +95,17 @@ def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, st @classmethod -def load_conv2d_no_bias(cls, prefix, weights, in_channels, out_channels, kernel_size, stride): +def load_conv2d_no_bias( + cls, prefix, weights, in_channels, out_channels, kernel_size, stride +): weight = weights.get_tensor(f"{prefix}.weight") with init_empty_weights(): - conv2d = cls(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride) + conv2d = cls( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + ) conv2d.weight = nn.Parameter(weight) conv2d.bias = None @@ -215,7 +230,10 @@ class Linear4bit(nn.Module): def __init__(self, weight, bias, quant_type): super().__init__() self.weight = Params4bit( - weight.data, requires_grad=False, compress_statistics=True, quant_type=quant_type + weight.data, + requires_grad=False, + compress_statistics=True, + quant_type=quant_type, ) self.compute_dtype = None self.weight.cuda(weight.device) @@ -246,7 +264,10 @@ class Linear4bit(nn.Module): @lru_cache(1) def warn_deprecate_bnb(): - logger.warning("Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce") + logger.warning( + "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performnce" + ) + def get_linear(weight, bias, quantize): if quantize is None: @@ -255,7 +276,9 @@ def get_linear(weight, bias, quantize): if HAS_EETQ: linear = EETQLinear(weight, bias) else: - raise ImportError("Please install EETQ from https://github.com/NetEase-FuXi/EETQ") + raise ImportError( + "Please install EETQ from https://github.com/NetEase-FuXi/EETQ" + ) elif quantize == "bitsandbytes": warn_deprecate_bnb() linear = Linear8bitLt( @@ -305,7 +328,14 @@ def get_linear(weight, bias, quantize): raise NotImplementedError( f"The passed weight is not `awq` compatible, loader needs to be updated." ) - linear = WQLinear(w_bit=bits, group_size=groupsize, qweight=qweight, qzeros=qzeros, scales=scales, bias=bias is not None) + linear = WQLinear( + w_bit=bits, + group_size=groupsize, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias is not None, + ) else: raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.") return linear @@ -392,9 +422,7 @@ class TensorParallelColumnLinear(SuperLayer): @classmethod def load_qkv(cls, config, prefix: str, weights, bias: bool): """Specific method when the QKV was joined after the fact""" - weight = weights.get_weights_col_packed_qkv( - prefix, quantize=config.quantize - ) + weight = weights.get_weights_col_packed_qkv(prefix, quantize=config.quantize) if bias: raise NotImplementedError("packed_qkv only implemented for baichuan") else: @@ -530,14 +558,16 @@ try: def _create_inv_freq(dim, base, device): inv_freq = 1.0 / ( - base - ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) + base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim) ) return inv_freq def _get_rope_config(config): if os.getenv("ROPE_SCALING", None) is not None: - rope_scaling = {"type": os.environ["ROPE_SCALING"], "factor": float(os.environ["ROPE_FACTOR"])} + rope_scaling = { + "type": os.environ["ROPE_SCALING"], + "factor": float(os.environ["ROPE_FACTOR"]), + } return rope_scaling return getattr(config, "rope_scaling", None) @@ -563,9 +593,17 @@ try: if rope_scaling["type"] == "linear": pass elif rope_scaling["type"] == "dynamic": - return DynamicPositionRotaryEmbedding(dim=dim, max_position_embeddings=config.max_position_embeddings, base=base, device=inv_freq.device, scaling_factor=scaling_factor) + return DynamicPositionRotaryEmbedding( + dim=dim, + max_position_embeddings=config.max_position_embeddings, + base=base, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) else: - raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) return cls(inv_freq, scaling_factor) @classmethod @@ -583,9 +621,17 @@ try: if rope_scaling["type"] == "linear": pass elif rope_scaling["type"] == "dynamic": - return DynamicPositionRotaryEmbedding(dim=2*inv_freq.shape[0], max_position_embeddings=config.max_position_embeddings, base=10000.0, device=inv_freq.device, scaling_factor=scaling_factor) + return DynamicPositionRotaryEmbedding( + dim=2 * inv_freq.shape[0], + max_position_embeddings=config.max_position_embeddings, + base=10000.0, + device=inv_freq.device, + scaling_factor=scaling_factor, + ) else: - raise NotImplementedError(f"rope scaling type {rope_scaling['type']} is not implemented or invalid") + raise NotImplementedError( + f"rope scaling type {rope_scaling['type']} is not implemented or invalid" + ) return cls(inv_freq, scaling_factor) def _update_cos_sin_cache(self, dtype, device, seqlen): @@ -645,8 +691,13 @@ try: or self._cos_cached.dtype != dtype ): if seqlen > self.max_position_embeddings: - newbase = self.base * ((self.scaling_factor * seqlen / self.max_position_embeddings) - (self.scaling_factor - 1)) ** (self.dim / (self.dim - 2)) - self.inv_freq = _create_inv_freq(self.dim, newbase, self.inv_freq.device) + newbase = self.base * ( + (self.scaling_factor * seqlen / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + self.inv_freq = _create_inv_freq( + self.dim, newbase, self.inv_freq.device + ) self._seq_len_cached = seqlen t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) # Don't do einsum, it converts fp32 to fp16 @@ -656,6 +707,5 @@ try: self._cos_cached = torch.cos(freqs).to(dtype) self._sin_cached = torch.sin(freqs).to(dtype) - except ImportError: pass diff --git a/server/text_generation_server/utils/peft.py b/server/text_generation_server/utils/peft.py index be1f9444..e37447dc 100644 --- a/server/text_generation_server/utils/peft.py +++ b/server/text_generation_server/utils/peft.py @@ -6,6 +6,7 @@ import torch from transformers import AutoTokenizer from peft import AutoPeftModelForCausalLM, AutoPeftModelForSeq2SeqLM + def download_and_unload_peft(model_id, revision, trust_remote_code): torch_dtype = torch.float16 @@ -33,7 +34,7 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): base_model_id = model.peft_config["default"].base_model_name_or_path model = model.merge_and_unload() - + os.makedirs(model_id, exist_ok=True) cache_dir = model_id logger.info(f"Saving the newly created merged model to {cache_dir}") @@ -41,6 +42,3 @@ def download_and_unload_peft(model_id, revision, trust_remote_code): model.save_pretrained(cache_dir, safe_serialization=True) model.config.save_pretrained(cache_dir) tokenizer.save_pretrained(cache_dir) - - - diff --git a/server/text_generation_server/utils/tokens.py b/server/text_generation_server/utils/tokens.py index 7b003f1d..f6339d7c 100644 --- a/server/text_generation_server/utils/tokens.py +++ b/server/text_generation_server/utils/tokens.py @@ -363,7 +363,7 @@ def batch_top_tokens( # Find the new "fuzzy" top n values top_n_indices = (logprobs >= nth_highest).nonzero() _, top_n_ishes = torch.unique_consecutive(top_n_indices[:, 0], return_counts=True) - + k = 1 if top_n_ishes.numel() == 0 else top_n_ishes.max() # Take a new topk for these new max n values top_k = torch.topk(logprobs, k=k, dim=1, sorted=True) diff --git a/server/text_generation_server/utils/weights.py b/server/text_generation_server/utils/weights.py index 266fcccb..8a19fd9f 100644 --- a/server/text_generation_server/utils/weights.py +++ b/server/text_generation_server/utils/weights.py @@ -62,7 +62,7 @@ class Weights: def get_shape(self, tensor_name: str): return self._get_slice(tensor_name).get_shape() - def get_tensor(self, tensor_name: str, to_device = True): + def get_tensor(self, tensor_name: str, to_device=True): filename, tensor_name = self.get_filename(tensor_name) f = self._get_handle(filename) tensor = f.get_tensor(tensor_name) @@ -110,7 +110,6 @@ class Weights: ), f"The choosen size {size} is not compatible with sharding on {world_size} shards" return self.get_partial_sharded(tensor_name, dim) - def _get_qweight(self, name: str): slice_ = self._get_slice(name) total_size = slice_.get_shape()[1] @@ -119,14 +118,16 @@ class Weights: world_size = self.process_group.size() rank = self.process_group.rank() - assert single_size % world_size == 0, f"Prepacked quantized qkv cannot be sharded across {world_size} shards" + assert ( + single_size % world_size == 0 + ), f"Prepacked quantized qkv cannot be sharded across {world_size} shards" block_size = single_size // world_size start = rank * block_size stop = (rank + 1) * block_size q = slice_[:, start:stop] - k = slice_[:, start+single_size:stop+single_size] - v = slice_[:, start+2*single_size:stop+2*single_size] - weight = torch.cat([q,k,v], dim=1) + k = slice_[:, start + single_size : stop + single_size] + v = slice_[:, start + 2 * single_size : stop + 2 * single_size] + weight = torch.cat([q, k, v], dim=1) weight = weight.to(device=self.device) return weight @@ -137,14 +138,14 @@ class Weights: """ if quantize in ["gptq", "awq"]: try: - qweight = self._get_qweight(f"{prefix}.qweight") + qweight = self._get_qweight(f"{prefix}.qweight") except RuntimeError: raise RuntimeError( f"Cannot load `{quantize}` weight, make sure the model is already quantized." ) - qzeros = self._get_qweight(f"{prefix}.qzeros") - scales = self._get_qweight(f"{prefix}.scales") + qzeros = self._get_qweight(f"{prefix}.qzeros") + scales = self._get_qweight(f"{prefix}.scales") scales = scales.to(dtype=self.dtype) if quantize == "gptq": g_idx = self.get_tensor(f"{prefix}.g_idx") @@ -154,21 +155,23 @@ class Weights: bits, groupsize = self._get_gptq_params() weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False) else: - slice_ = self._get_slice(f"{prefix}.weight") + slice_ = self._get_slice(f"{prefix}.weight") total_size = slice_.get_shape()[0] assert total_size % 3 == 0, "Prepacked qkv is not divisible by 3" single_size = total_size // 3 world_size = self.process_group.size() rank = self.process_group.rank() - assert single_size % world_size == 0, f"Prepacked qkv cannot be sharded across {world_size} shards" + assert ( + single_size % world_size == 0 + ), f"Prepacked qkv cannot be sharded across {world_size} shards" block_size = single_size // world_size start = rank * block_size stop = (rank + 1) * block_size q = slice_[start:stop] - k = slice_[start+single_size:stop+single_size] - v = slice_[start+2*single_size:stop+2*single_size] - weight = torch.cat([q,k,v], dim=0) + k = slice_[start + single_size : stop + single_size] + v = slice_[start + 2 * single_size : stop + 2 * single_size] + weight = torch.cat([q, k, v], dim=0) weight = weight.to(device=self.device) weight = weight.to(dtype=self.dtype) return weight @@ -205,7 +208,7 @@ class Weights: w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes] weight = torch.cat(w, dim=dim) return weight - + def get_tensor_shard(self, var, dim): world_size = self.process_group.size() rank = self.process_group.rank() @@ -220,7 +223,7 @@ class Weights: raise NotImplementedError("Let's make that generic when needed") tensor = tensor.to(dtype=self.dtype) tensor = tensor.to(device=self.device) - return tensor + return tensor def get_multi_weights_row(self, prefix: str, quantize: str): if quantize == "gptq": @@ -303,7 +306,7 @@ class Weights: scales = self.get_sharded(f"{prefix}.scales", dim=0) g_idx = None use_exllama = False - + weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama) else: weight = self.get_sharded(f"{prefix}.weight", dim=1)