From bf700e7eef4771f280c19dbc7270c8c7c20efbbc Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 26 Feb 2024 19:49:28 +0100
Subject: [PATCH] Revamp medusa implementation so that every model can benefit. (#1588)

# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
---
 integration-tests/conftest.py | 8 ++
 integration-tests/models/test_flash_medusa.py | 4 +-
 server/text_generation_server/cli.py | 14 +--
 .../text_generation_server/models/__init__.py | 86 +++++++++--------
 server/text_generation_server/models/bloom.py | 6 +-
 .../models/causal_lm.py | 13 ++-
 .../models/custom_modeling/bloom_modeling.py | 21 +++--
 .../custom_modeling/flash_gemma_modeling.py | 10 +-
 .../custom_modeling/flash_llama_modeling.py | 10 +-
 .../custom_modeling/flash_mistral_modeling.py | 4 +-
 .../custom_modeling/flash_mixtral_modeling.py | 4 +-
 .../custom_modeling/flash_neox_modeling.py | 4 +-
 .../custom_modeling/flash_phi_modeling.py | 4 +-
 .../custom_modeling/flash_rw_modeling.py | 6 +-
 .../flash_santacoder_modeling.py | 4 +-
 .../custom_modeling/idefics_modeling.py | 29 +++---
 .../models/custom_modeling/mamba_modeling.py | 11 +--
 .../models/custom_modeling/mpt_modeling.py | 21 +++--
 .../models/custom_modeling/neox_modeling.py | 4 +-
 .../models/custom_modeling/opt_modeling.py | 4 +-
 .../models/custom_modeling/phi_modeling.py | 4 +-
 .../models/custom_modeling/t5_modeling.py | 31 ++++---
 .../models/flash_causal_lm.py | 24 +++--
 .../models/flash_gemma.py | 33 +------
 .../models/flash_llama.py | 34 +------
 .../models/flash_mistral.py | 24 ++++-
 .../models/flash_mixtral.py | 2 +
 .../models/flash_neox.py | 2 +
 .../models/flash_phi.py | 3 +-
 .../text_generation_server/models/flash_rw.py | 2 +
 .../models/flash_santacoder.py | 2 +
 .../text_generation_server/models/idefics.py | 2 +
 .../models/idefics_causal_lm.py | 11 ++-
 server/text_generation_server/models/mamba.py | 18 +++-
 server/text_generation_server/models/mpt.py | 2 +
 server/text_generation_server/models/opt.py | 2 +
 server/text_generation_server/models/phi.py | 2 +
 .../models/santacoder.py | 1 +
 .../models/seq2seq_lm.py | 11 ++-
 server/text_generation_server/models/t5.py | 5 +-
 server/text_generation_server/utils/hub.py | 2 +
 server/text_generation_server/utils/layers.py | 92 ++++++++++++++++++-
 server/text_generation_server/utils/medusa.py | 59 ------------
 43 files changed, 352 insertions(+), 283 deletions(-)
 delete mode 100644 server/text_generation_server/utils/medusa.py

diff --git a/integration-tests/conftest.py b/integration-tests/conftest.py
index 80457bc2..e11c7cf9 100644
--- a/integration-tests/conftest.py
+++ 
b/integration-tests/conftest.py @@ -236,6 +236,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + revision: Optional[str] = None, ): port = random.randint(8000, 10_000) master_port = random.randint(10_000, 20_000) @@ -268,6 +269,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) if trust_remote_code: args.append("--trust-remote-code") @@ -302,6 +306,7 @@ def launcher(event_loop): use_flash_attention: bool = True, disable_grammar_support: bool = False, dtype: Optional[str] = None, + revision: Optional[str] = None, ): port = random.randint(8000, 10_000) @@ -317,6 +322,9 @@ def launcher(event_loop): if dtype is not None: args.append("--dtype") args.append(dtype) + if revision is not None: + args.append("--revision") + args.append(revision) if trust_remote_code: args.append("--trust-remote-code") diff --git a/integration-tests/models/test_flash_medusa.py b/integration-tests/models/test_flash_medusa.py index e0cc1039..27db5665 100644 --- a/integration-tests/models/test_flash_medusa.py +++ b/integration-tests/models/test_flash_medusa.py @@ -3,7 +3,9 @@ import pytest @pytest.fixture(scope="module") def flash_medusa_handle(launcher): - with launcher("FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2) as handle: + with launcher( + "FasterDecoding/medusa-vicuna-7b-v1.3", num_shard=2, revision="refs/pr/1" + ) as handle: yield handle diff --git a/server/text_generation_server/cli.py b/server/text_generation_server/cli.py index b74fbe36..a513f5e6 100644 --- a/server/text_generation_server/cli.py +++ b/server/text_generation_server/cli.py @@ -154,12 +154,8 @@ def download_weights( import json medusa_head = hf_hub_download( - model_id, revision=revision, filename="medusa_lm_head.pt" + model_id, revision=revision, filename="medusa_lm_head.safetensors" ) - if auto_convert: - medusa_sf = Path(medusa_head[: -len(".pt")] + ".safetensors") - if not medusa_sf.exists(): - utils.convert_files([Path(medusa_head)], [medusa_sf], []) medusa_config = hf_hub_download( model_id, revision=revision, filename="config.json" ) @@ -198,16 +194,12 @@ def download_weights( if not extension == ".safetensors" or not auto_convert: raise e - elif (Path(model_id) / "medusa_lm_head.pt").exists(): + elif (Path(model_id) / "medusa_lm_head.safetensors").exists(): # Try to load as a local Medusa model try: import json - medusa_head = Path(model_id) / "medusa_lm_head.pt" - if auto_convert: - medusa_sf = Path(model_id) / "medusa_lm_head.safetensors" - if not medusa_sf.exists(): - utils.convert_files([Path(medusa_head)], [medusa_sf], []) + medusa_head = Path(model_id) / "medusa_lm_head.safetensors" medusa_config = Path(model_id) / "config.json" with open(medusa_config, "r") as f: config = json.load(f) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index abab3486..3208275c 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -3,7 +3,9 @@ import torch from loguru import logger from transformers.configuration_utils import PretrainedConfig from transformers.models.auto import modeling_auto +from huggingface_hub import hf_hub_download from typing import Optional +from pathlib import Path from text_generation_server.utils.speculate import get_speculate, set_speculate from text_generation_server.models.model 
import Model @@ -115,44 +117,14 @@ def get_model( else: set_speculate(0) - if "facebook/galactica" in model_id: - return GalacticaSharded( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - - if model_id.startswith("bigcode/"): - if FLASH_ATTENTION: - return FlashSantacoderSharded( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - elif sharded: - raise NotImplementedError( - FLASH_ATT_ERROR_MESSAGE.format("Sharded Santacoder") - ) - else: - return SantaCoder( - model_id, - revision, - quantize=quantize, - dtype=dtype, - trust_remote_code=trust_remote_code, - ) - config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) use_medusa = None if "medusa_num_heads" in config_dict: - use_medusa = model_id + medusa_model_id = model_id + medusa_revision = revision model_id = config_dict["base_model_name_or_path"] revision = "main" speculate_medusa = config_dict["medusa_num_heads"] @@ -169,6 +141,20 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + is_local = Path(medusa_model_id).exists() + if not is_local: + medusa_config = hf_hub_download( + medusa_model_id, revision=medusa_revision, filename="config.json" + ) + hf_hub_download( + medusa_model_id, + revision=medusa_revision, + filename="medusa_lm_head.safetensors", + ) + use_medusa = Path(medusa_config).parent + else: + use_medusa = Path(medusa_model_id) + method = "medusa" else: method = "n-gram" @@ -193,16 +179,22 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) - if model_type == "gpt_bigcode": + if ( + model_type == "gpt_bigcode" + or model_type == "gpt2" + and model_id.startswith("bigcode/") + ): if FLASH_ATTENTION: return FlashSantacoderSharded( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -215,6 +207,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -224,6 +217,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -232,6 +226,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -242,6 +237,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -250,6 +246,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -258,6 +255,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -268,15 +266,16 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa, ) else: return CausalLM( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -291,6 +290,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -301,9 +301,9 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, 
dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa, ) elif sharded: raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Llama")) @@ -312,6 +312,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -321,9 +322,9 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, - use_medusa=use_medusa, ) elif sharded: raise NotImplementedError( @@ -334,6 +335,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -347,6 +349,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -357,6 +360,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -365,6 +369,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -378,6 +383,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -391,6 +397,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -400,6 +407,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -409,6 +417,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -418,6 +427,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -441,6 +451,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -449,6 +460,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -460,6 +472,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -468,6 +481,7 @@ def get_model( model_id, revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index fed5e6f3..67129ec3 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -42,6 +42,7 @@ class BLOOMSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -70,6 +71,7 @@ class BLOOMSharded(CausalLM): ) config.pad_token_id = 3 config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") @@ -103,7 +105,7 @@ class BLOOMSharded(CausalLM): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None ): - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ 
-112,4 +114,4 @@ class BLOOMSharded(CausalLM): ) logits = outputs.logits - return logits, outputs.past_key_values + return logits, speculative_logits, outputs.past_key_values diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index a0f0c9e8..bbcef210 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -482,6 +482,7 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -550,7 +551,9 @@ class CausalLM(Model): def forward( self, input_ids, attention_mask, position_ids, past_key_values: Optional = None - ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]] + ]: # Model Forward kwargs = { "input_ids": input_ids, @@ -563,7 +566,11 @@ class CausalLM(Model): kwargs["position_ids"] = position_ids outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values + if isinstance(outputs, tuple): + outputs, speculative_logits = outputs + else: + speculative_logits = None + return outputs.logits, speculative_logits, outputs.past_key_values @tracer.start_as_current_span("generate_token") def generate_token( @@ -573,7 +580,7 @@ class CausalLM(Model): # slice the attention mask to the correct shape attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] - logits, past = self.forward( + logits, speculative_logits, past = self.forward( batch.input_ids, attention_mask, batch.position_ids, diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 5423d75a..10b40483 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -36,7 +36,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) CUSTOM_KERNELS_ENABLED = False @@ -820,7 +820,7 @@ class BloomForCausalLM(BloomPreTrainedModel): super().__init__(config) self.transformer = BloomModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="word_embeddings", weights=weights, @@ -904,17 +904,20 @@ class BloomForCausalLM(BloomPreTrainedModel): ) hidden_states = transformer_outputs[0] - lm_logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) loss = None if not return_dict: output = (lm_logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return CausalLMOutputWithCrossAttentions( - loss=loss, - logits=lm_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, + return ( + CausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ), + speculative_logits, ) diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py 
b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index 4a08bc2a..e91927df 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -37,7 +37,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -575,7 +575,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): super().__init__() self.model = FlashGemmaModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.embed_tokens" if config.tie_word_embeddings else "lm_head", weights=weights, @@ -592,7 +592,7 @@ class FlashGemmaForCausalLM(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, position_ids, @@ -605,5 +605,5 @@ class FlashGemmaForCausalLM(torch.nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits = self.lm_head(hidden_states) - return logits + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 1626eb4d..3a269fc0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -410,7 +410,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): super().__init__() self.model = FlashLlamaModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, @@ -427,7 +427,7 @@ class FlashLlamaForCausalLM(torch.nn.Module): input_lengths: torch.Tensor, max_s: int, lm_head_indices: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.model( input_ids, position_ids, @@ -440,5 +440,5 @@ class FlashLlamaForCausalLM(torch.nn.Module): ) if lm_head_indices is not None: hidden_states = hidden_states[lm_head_indices] - logits = self.lm_head(hidden_states) - return logits + logits, speculative_logits = self.lm_head(hidden_states) + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index fda34e5a..ed9306e0 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastRMSNorm, ) @@ -419,7 +419,7 @@ class FlashMistralForCausalLM(torch.nn.Module): super().__init__() self.model = MistralModel(config, weights) - self.lm_head = TensorParallelHead.load( 
+ self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 3d3caba3..17d4f708 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -37,7 +37,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, ) @@ -810,7 +810,7 @@ class FlashMixtralForCausalLM(torch.nn.Module): super().__init__() self.model = MixtralModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 780861c2..ee062d3d 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -33,7 +33,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLayerNorm, PositionRotaryEmbedding, get_linear, @@ -369,7 +369,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): super().__init__(config) self.gpt_neox = FlashGPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( + self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index a9a929e9..cfe447a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -12,7 +12,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, PositionRotaryEmbedding, - TensorParallelHead, + SpeculativeHead, get_linear, FastLayerNorm, ) @@ -376,7 +376,7 @@ class FlashPhiForCausalLM(torch.nn.Module): super().__init__() self.model = FlashPhiModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights, diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 6a530f3c..a9127d1f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -12,7 +12,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLayerNorm, PositionRotaryEmbedding, get_linear, @@ -613,9 +613,7 @@ class FlashRWForCausalLM(FlashRWPreTrainedModel): self.transformer = FlashRWModel(config, weights) - self.lm_head = TensorParallelHead.load( - config, prefix="lm_head", weights=weights - ) + self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) def forward( self, diff --git 
a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index d3fe95d0..bbb603a7 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -9,7 +9,7 @@ from text_generation_server.utils import paged_attention, flash_attn from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, - TensorParallelHead, + SpeculativeHead, TensorParallelEmbedding, FastLayerNorm, get_linear, @@ -453,7 +453,7 @@ class FlashSantacoderForCausalLM(nn.Module): def __init__(self, config, weights): super().__init__() self.transformer = FlashSantacoderModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/idefics_modeling.py b/server/text_generation_server/models/custom_modeling/idefics_modeling.py index 4f7dfb95..ee4cdb08 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_modeling.py +++ b/server/text_generation_server/models/custom_modeling/idefics_modeling.py @@ -51,7 +51,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, PositionRotaryEmbedding, FastLinear, ) @@ -272,9 +272,7 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): weights, ) -> None: super().__init__() - self.fc = TensorParallelHead.load( - config=config, prefix="lm_head", weights=weights - ) + self.fc = SpeculativeHead.load(config=config, prefix="lm_head", weights=weights) self.additional_fc = FastLinear.load( config=config, prefix="lm_head.additional_fc", @@ -283,11 +281,11 @@ class IdeficsDecoupledTensorParallelLinear(nn.Module): ) def forward(self, input: torch.Tensor) -> torch.Tensor: - output = self.fc(input) + output, speculative_logits = self.fc(input) additional_features = self.additional_fc(input) output = torch.cat((output, additional_features), -1) - return output + return output, speculative_logits def extra_repr(self) -> str: """Overwriting `nn.Linear.extra_repr` to include new parameters.""" @@ -1503,17 +1501,20 @@ class IdeficsForVisionText2Text(IdeficsPreTrainedModel): ) hidden_states = outputs[0] - logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) loss = None - return CausalLMOutputWithPastImage( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - image_hidden_states=outputs.image_hidden_states, + return ( + CausalLMOutputWithPastImage( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=outputs.image_hidden_states, + ), + speculative_logits, ) def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index baf1fb85..c58a617f 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -9,6 +9,7 @@ from 
transformers.configuration_utils import PretrainedConfig import torch.nn.functional as F from text_generation_server.utils.layers import ( + SpeculativeHead, TensorParallelEmbedding, FastRMSNorm, FastLinear, @@ -205,14 +206,12 @@ class MambaModel(nn.Module): self.norm_f = FastRMSNorm.load( f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon ) - self.lm_head = FastLinear.load( - config, f"{prefix}.embedding", weights, bias=False - ) + self.lm_head = SpeculativeHead.load(config, f"{prefix}.embedding", weights) self.config = config def forward( self, input_ids: torch.Tensor, inference_params=None, residual=None - ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: hidden_states = self.embed_tokens(input_ids) for i, block in enumerate(self.blocks): hidden_states, residual, conv_state, ssm_state = block( @@ -226,8 +225,8 @@ class MambaModel(nn.Module): ) hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) hidden_states = hidden_states.view(residual.shape) - logits = self.lm_head(hidden_states) + logits, speculative_logits = self.lm_head(hidden_states) # update the offset for the next inference using these params inference_params.seqlen_offset += input_ids.size(1) - return logits + return logits, speculative_logits diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index 2e2e423e..9b0f8b92 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -21,7 +21,7 @@ from text_generation_server.utils.layers import ( TensorParallelEmbedding, TensorParallelColumnLinear, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, get_linear, ) @@ -1090,7 +1090,7 @@ class MPTForCausalLM(MPTPreTrainedModel): if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") self.transformer = MPTModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="transformer.wte", weights=weights ) self.logit_scale = None @@ -1133,7 +1133,7 @@ class MPTForCausalLM(MPTPreTrainedModel): output_hidden_states=output_hidden_states, use_cache=use_cache, ) - logits = self.lm_head(outputs.last_hidden_state) + logits, speculative_logits = self.lm_head(outputs.last_hidden_state) if self.logit_scale is not None: if self.logit_scale == 0: warnings.warn( @@ -1147,12 +1147,15 @@ class MPTForCausalLM(MPTPreTrainedModel): loss = F.cross_entropy( logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1) ) - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, + return ( + CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index dbcefbae..2550d2d1 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -44,7 +44,7 @@ from text_generation_server.utils.layers import ( 
TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) @@ -646,7 +646,7 @@ class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): def __init__(self, config, weights): super().__init__(config) self.gpt_neox = GPTNeoXModel(config, weights) - self.embed_out = TensorParallelHead.load( + self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index ce3f5e21..de5e95af 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -32,7 +32,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) EPS = 1e-5 @@ -748,7 +748,7 @@ class OPTForCausalLM(OPTPreTrainedModel): self.model = OPTModel(config, weights) - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="model.decoder.embed_tokens", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index e5c09728..1571f9fd 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -13,7 +13,7 @@ from text_generation_server.utils.layers import ( TensorParallelRowLinear, TensorParallelColumnLinear, TensorParallelEmbedding, - TensorParallelHead, + SpeculativeHead, FastLinear, ) @@ -120,7 +120,7 @@ class PhiCausalLMHead(nn.Module): weights=weights, eps=config.layer_norm_epsilon, ) - self.linear = TensorParallelHead.load( + self.linear = SpeculativeHead.load( config=config, prefix="lm_head.linear", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/t5_modeling.py b/server/text_generation_server/models/custom_modeling/t5_modeling.py index d3e4f53a..2773fb15 100644 --- a/server/text_generation_server/models/custom_modeling/t5_modeling.py +++ b/server/text_generation_server/models/custom_modeling/t5_modeling.py @@ -42,7 +42,7 @@ from text_generation_server.utils.layers import ( TensorParallelColumnLinear, TensorParallelEmbedding, TensorParallelRowLinear, - TensorParallelHead, + SpeculativeHead, ) @@ -1033,14 +1033,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ) try: - self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="lm_head", weights=weights ) except RuntimeError: # Some models like t5-small were saved with shared weights unlike flan # Since they are declared as the same arch we have no choice but hope # that this is OK instead of using a proper flag. 
- self.lm_head = TensorParallelHead.load( + self.lm_head = SpeculativeHead.load( config, prefix="shared", weights=weights ) @@ -1126,7 +1126,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 sequence_output = sequence_output * (self.model_dim**-0.5) - lm_logits = self.lm_head(sequence_output) + logits, speculative_logits = self.lm_head(sequence_output) loss = None if labels is not None: @@ -1140,16 +1140,19 @@ class T5ForConditionalGeneration(T5PreTrainedModel): output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs return ((loss,) + output) if loss is not None else output - return Seq2SeqLMOutput( - loss=loss, - logits=lm_logits, - past_key_values=decoder_outputs.past_key_values, - decoder_hidden_states=decoder_outputs.hidden_states, - decoder_attentions=decoder_outputs.attentions, - cross_attentions=decoder_outputs.cross_attentions, - encoder_last_hidden_state=encoder_outputs.last_hidden_state, - encoder_hidden_states=encoder_outputs.hidden_states, - encoder_attentions=encoder_outputs.attentions, + return ( + Seq2SeqLMOutput( + loss=loss, + logits=logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ), + speculative_logits, ) def prepare_inputs_for_generation( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index b8d0be22..988637d4 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -723,7 +723,7 @@ class FlashCausalLM(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - self.cuda_graphs[bs]["logits"] = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, @@ -734,6 +734,8 @@ class FlashCausalLM(Model): max_s=max_s, lm_head_indices=None, ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() def warmup(self, batch: FlashCausalLMBatch): @@ -805,7 +807,9 @@ class FlashCausalLM(Model): return int(num_blocks * BLOCK_SIZE) - def forward(self, batch: FlashCausalLMBatch) -> torch.Tensor: + def forward( + self, batch: FlashCausalLMBatch + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: input_ids = batch.input_ids @@ -900,9 +904,14 @@ class FlashCausalLM(Model): # Replay the graph cuda_graph["graph"].replay() - # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits @tracer.start_as_current_span("generate_token") def generate_token( @@ -926,16 +935,11 @@ class FlashCausalLM(Model): batch.slots = slots try: - out = self.forward(batch) + out, speculative_logits = self.forward(batch) except Exception as e: del batch raise e - if isinstance(out, tuple): - out, speculative_logits = out - else: - 
speculative_logits = None - if prefill: next_token_logits = ( out[batch.prefill_next_token_indices] if prefill_logprobs else out diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 220b3992..8cfb6631 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -25,9 +25,9 @@ class FlashGemma(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -50,6 +50,7 @@ class FlashGemma(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) @@ -59,36 +60,6 @@ class FlashGemma(FlashCausalLM): weights._set_gptq_params(model_id, revision) model = FlashGemmaForCausalLM(config, weights) - if use_medusa: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(use_medusa).exists() and Path(use_medusa).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - medusa_config = hf_hub_download( - use_medusa, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - use_medusa, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(use_medusa) / "config.json") - medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) torch.distributed.barrier(group=self.process_group) super(FlashGemma, self).__init__( diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index 94bd58f4..a2ac759a 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -26,9 +26,9 @@ class FlashLlama(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -58,6 +58,7 @@ class FlashLlama(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) @@ -67,37 +68,6 @@ class FlashLlama(FlashCausalLM): weights._set_gptq_params(model_id, revision) model = FlashLlamaForCausalLM(config, weights) - if use_medusa: - from text_generation_server.utils.medusa import MedusaModel - from huggingface_hub import hf_hub_download - import json - import os - from pathlib import Path - - is_local_model = ( - Path(use_medusa).exists() and Path(use_medusa).is_dir() - ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None - - if not is_local_model: - 
medusa_config = hf_hub_download( - use_medusa, revision=revision, filename="config.json" - ) - medusa_head = hf_hub_download( - use_medusa, revision=revision, filename="medusa_lm_head.pt" - ) - else: - medusa_config = str(Path(use_medusa) / "config.json") - medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") - - with open(medusa_config, "r") as f: - config = json.load(f) - medusa_sf = medusa_head[: -len(".pt")] + ".safetensors" - weights = Weights( - [medusa_sf], device, dtype, process_group=self.process_group - ) - lm_head = model.lm_head - model.lm_head = MedusaModel(config, weights, lm_head) - torch.distributed.barrier(group=self.process_group) super(FlashLlama, self).__init__( model=model, diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 5df4e214..d3c0da9c 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -294,6 +294,7 @@ class BaseFlashMistral(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -319,6 +320,7 @@ class BaseFlashMistral(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa # Set context windows if config.sliding_window is not None: @@ -394,7 +396,7 @@ class BaseFlashMistral(FlashCausalLM): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - self.cuda_graphs[bs]["logits"] = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=None, @@ -406,9 +408,13 @@ class BaseFlashMistral(FlashCausalLM): prefill_cache_indices=None, lm_head_indices=None, ) + self.cuda_graphs[bs]["logits"] = logits + self.cuda_graphs[bs]["speculative_logits"] = speculative_logits torch.cuda.synchronize() - def forward(self, batch: FlashMistralBatch) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: FlashMistralBatch + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: # Model Forward if batch.speculative_ids is not None: input_ids = batch.input_ids @@ -479,7 +485,7 @@ class BaseFlashMistral(FlashCausalLM): cuda_graph = self.cuda_graphs.get(padded_bs, None) if cu_seqlen_prefill is not None or cuda_graph is None: - logits = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, position_ids=position_ids, cu_seqlen_prefill=cu_seqlen_prefill, @@ -493,7 +499,7 @@ class BaseFlashMistral(FlashCausalLM): ) if batch.prefill_cache_indices is not None: batch.prefill_cache_indices = None - return logits + return logits, speculative_logits # Copy inputs to the static inputs of the cuda graph # Static inputs are potentially padded @@ -511,7 +517,13 @@ class BaseFlashMistral(FlashCausalLM): cuda_graph["graph"].replay() # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits class FlashMistral(BaseFlashMistral): @@ -520,6 +532,7 @@ class FlashMistral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = 
False, ): @@ -529,6 +542,7 @@ class FlashMistral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py index 6f77a658..2ee35e82 100644 --- a/server/text_generation_server/models/flash_mixtral.py +++ b/server/text_generation_server/models/flash_mixtral.py @@ -15,6 +15,7 @@ class FlashMixtral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -24,6 +25,7 @@ class FlashMixtral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, + use_medusa=use_medusa, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 80f8804d..5a351bd7 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -24,6 +24,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -46,6 +47,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index 061b9740..cb55f9e6 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -25,9 +25,9 @@ class FlashPhi(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, - use_medusa: Optional[str] = None, ): self.process_group, rank, world_size = initialize_torch_distributed() if torch.cuda.is_available(): @@ -48,6 +48,7 @@ class FlashPhi(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index dfab8888..fc1e26bd 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -25,6 +25,7 @@ class FlashRWSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -61,6 +62,7 @@ class FlashRWSharded(FlashCausalLM): ) config.quantize = quantize + config.use_medusa = use_medusa if config.quantize == "gptq": weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 22171ec0..034949f9 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ 
b/server/text_generation_server/models/flash_santacoder.py @@ -27,6 +27,7 @@ class FlashSantacoderSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,6 +52,7 @@ class FlashSantacoderSharded(FlashCausalLM): trust_remote_code=True, ) config.quantize = quantize + config.use_medusa = use_medusa config.transpose = config.architectures[0].startswith("GPT2") torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index fa23d1f9..baa1945b 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -31,6 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,6 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa config.vision_config.quantize = quantize tokenizer = LlamaTokenizerFast.from_pretrained( diff --git a/server/text_generation_server/models/idefics_causal_lm.py b/server/text_generation_server/models/idefics_causal_lm.py index a6df2ebe..c96e8152 100644 --- a/server/text_generation_server/models/idefics_causal_lm.py +++ b/server/text_generation_server/models/idefics_causal_lm.py @@ -662,8 +662,13 @@ class IdeficsCausalLM(Model): if self.has_position_ids: kwargs["position_ids"] = position_ids - outputs = self.model.forward(**kwargs) - return outputs.logits, outputs.past_key_values, outputs.image_hidden_states + outputs, speculative_logits = self.model.forward(**kwargs) + return ( + outputs.logits, + speculative_logits, + outputs.past_key_values, + outputs.image_hidden_states, + ) @tracer.start_as_current_span("generate_token") def generate_token( @@ -686,7 +691,7 @@ class IdeficsCausalLM(Model): :, : -batch.padding_right_offset ] - logits, past, image_hidden_states = self.forward( + logits, speculative_logits, past, image_hidden_states = self.forward( input_ids=batch.input_ids, attention_mask=attention_mask, position_ids=batch.position_ids, diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 9d59f424..2500d454 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -408,6 +408,7 @@ class Mamba(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -444,6 +445,7 @@ class Mamba(Model): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) @@ -505,7 +507,7 @@ class Mamba(Model): torch.cuda.synchronize() with torch.cuda.graph(graph, pool=MEM_POOL): - logits = self.model.forward( + logits, speculative_logits = self.model.forward( input_ids=input_ids, inference_params=inference_params ) torch.cuda.synchronize() @@ -514,6 +516,7 @@ class Mamba(Model): "inference_params": 
inference_params, "graph": graph, "logits": logits, + "speculative_logits": speculative_logits, } self.cuda_graphs[batch_size] = graph_dict @@ -556,9 +559,14 @@ class Mamba(Model): inference_params.ssm_states.copy_( cuda_graph["inference_params"].ssm_states[:, :bs] ) - # Slice output to the correct shape - return cuda_graph["logits"][:bs] + speculative_logits = ( + cuda_graph["speculative_logits"][:bs] + if cuda_graph["speculative_logits"] is not None + else None + ) + logits = cuda_graph["logits"][:bs] + return logits, speculative_logits def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() @@ -589,7 +597,9 @@ class Mamba(Model): batch.inference_params = inference_params # Forward pass - logits = self.forward(input_ids, inference_params=batch.inference_params) + logits, speculative_logits = self.forward( + input_ids, inference_params=batch.inference_params + ) # batch.inference_params = new_inference_params # Results diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index e419467f..6b3f29a6 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -43,6 +43,7 @@ class MPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -75,6 +76,7 @@ class MPTSharded(CausalLM): config = json.load(f) config = PretrainedConfig(**config) config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 58fb212f..703e5b58 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -22,6 +22,7 @@ class OPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -47,6 +48,7 @@ class OPTSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index 79aa3fb9..cc4e2505 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -22,6 +22,7 @@ class Phi(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -52,6 +53,7 @@ class Phi(CausalLM): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize + config.use_medusa = use_medusa torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 7b269d8e..73c21cce 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -19,6 +19,7 @@ class SantaCoder(CausalLM): model_id: str, revision: Optional[str] = None, 
quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index 777a55ba..fae9a2df 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ b/server/text_generation_server/models/seq2seq_lm.py @@ -532,6 +532,7 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -596,6 +597,7 @@ class Seq2SeqLM(Model): past_key_values: Optional = None, ) -> Tuple[ torch.Tensor, + Optional[torch.Tensor], torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: @@ -609,8 +611,15 @@ class Seq2SeqLM(Model): past_key_values=past_key_values, use_cache=True, ) + if isinstance(outputs, tuple): + # Our custom models + outputs, speculative_logits = outputs + else: + # Generic transformers models + speculative_logits = None return ( outputs.logits, + speculative_logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) @@ -635,7 +644,7 @@ class Seq2SeqLM(Model): else: encoder_last_hidden_state = None - logits, encoder_last_hidden_state, past = self.forward( + logits, speculative_logits, encoder_last_hidden_state, past = self.forward( batch.input_ids, batch.attention_mask, batch.decoder_input_ids, diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 161e69ba..3f3cb965 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -25,6 +25,7 @@ class T5Sharded(Seq2SeqLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, + use_medusa: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -42,6 +43,7 @@ class T5Sharded(Seq2SeqLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize + config.use_medusa = use_medusa tokenizer = AutoTokenizer.from_pretrained( model_id, @@ -94,7 +96,7 @@ class T5Sharded(Seq2SeqLM): List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]], ]: # Model Forward - outputs = self.model.forward( + outputs, speculative_logits = self.model.forward( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -106,6 +108,7 @@ class T5Sharded(Seq2SeqLM): return ( outputs.logits, + speculative_logits, outputs.encoder_last_hidden_state, outputs.past_key_values, ) diff --git a/server/text_generation_server/utils/hub.py b/server/text_generation_server/utils/hub.py index b56484f6..a81e659d 100644 --- a/server/text_generation_server/utils/hub.py +++ b/server/text_generation_server/utils/hub.py @@ -40,6 +40,7 @@ def _weight_hub_files_from_model_info( and "arguments" not in s.rfilename and "args" not in s.rfilename and "training" not in s.rfilename + and "medusa_lm_head" not in s.rfilename ] @@ -56,6 +57,7 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]: and "args" not in f and "adapter" not in f and "training" not in f + and "medusa_lm_head" not in f ] return filenames diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index bef2a146..209f1c8a 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -4,7 +4,7 @@ import torch.distributed from 
torch import nn from torch.nn import functional as F -from typing import List +from typing import List, Tuple, Optional from loguru import logger from functools import lru_cache @@ -380,6 +380,96 @@ class SuperLayer(nn.Module): return self.linear.forward(x) +class ResBlock(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.linear = FastLinear.load( + config, prefix=f"{prefix}.linear", weights=weights, bias=True + ) + self.act = torch.nn.SiLU() + + def forward(self, x): + return x + self.act(self.linear(x)) + + +class MedusaModel(torch.nn.Module): + def __init__(self, config, weights): + super().__init__() + self.heads = torch.nn.ModuleList( + [ + MedusaHead(config, prefix=f"{i}", weights=weights) + for i in range(config["medusa_num_heads"]) + ] + ) + + def forward(self, x): + speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) + return speculative_logits + + +class MedusaHead(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.blocks = torch.nn.ModuleList( + [ + ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) + for i in range(config["medusa_num_layers"]) + ] + ) + n = len(self.blocks) + self.out = FastLinear.load( + config, prefix=f"{prefix}.{n}", weights=weights, bias=False + ) + + def forward(self, x): + for block in self.blocks: + x = block(x) + x = self.out(x) + return x + + +class SpeculativeHead(nn.Module): + def __init__(self, lm_head, medusa): + super().__init__() + self.lm_head = lm_head + self.medusa = medusa + + @staticmethod + def load(config, prefix: str, weights): + lm_head = TensorParallelHead.load(config, prefix, weights) + use_medusa = config.use_medusa + if use_medusa: + from pathlib import Path + from safetensors import safe_open + import json + + medusa_config = str(Path(use_medusa) / "config.json") + filename = str(Path(use_medusa) / "medusa_lm_head.safetensors") + + with open(medusa_config, "r") as f: + config = json.load(f) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + weights.routing[k] = filename + + medusa = MedusaModel(config, weights) + else: + medusa = None + return SpeculativeHead(lm_head, medusa) + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.lm_head(input) + speculative_logits = self.medusa(input) if self.medusa is not None else None + return logits, speculative_logits + + class TensorParallelHead(SuperLayer): def __init__(self, linear, process_group, should_gather: bool): super().__init__(linear) diff --git a/server/text_generation_server/utils/medusa.py b/server/text_generation_server/utils/medusa.py deleted file mode 100644 index 634119cb..00000000 --- a/server/text_generation_server/utils/medusa.py +++ /dev/null @@ -1,59 +0,0 @@ -import torch -from dataclasses import dataclass -from text_generation_server.utils.layers import TensorParallelHead, FastLinear - - -@dataclass -class Output: - logits: torch.FloatTensor = None - speculative_logits: torch.FloatTensor = None - - -class ResBlock(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.linear = FastLinear.load( - config, prefix=f"{prefix}.linear", weights=weights, bias=True - ) - self.act = torch.nn.SiLU() - - def forward(self, x): - return x + self.act(self.linear(x)) - - -class MedusaModel(torch.nn.Module): - def 
__init__(self, config, weights, lm_head): - super().__init__() - self.heads = torch.nn.ModuleList( - [ - MedusaHead(config, prefix=f"{i}", weights=weights) - for i in range(config["medusa_num_heads"]) - ] - ) - self.lm_head = lm_head - - def forward(self, x): - logits = self.lm_head(x) - speculative_logits = torch.stack([head(x) for head in self.heads], dim=1) - return logits, speculative_logits - - -class MedusaHead(torch.nn.Module): - def __init__(self, config, prefix, weights): - super().__init__() - self.blocks = torch.nn.ModuleList( - [ - ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) - for i in range(config["medusa_num_layers"]) - ] - ) - n = len(self.blocks) - self.out = FastLinear.load( - config, prefix=f"{prefix}.{n}", weights=weights, bias=False - ) - - def forward(self, x): - for block in self.blocks: - x = block(x) - x = self.out(x) - return x
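
Reviewer note: the new `SpeculativeHead`, `MedusaModel`, `MedusaHead`, and `ResBlock` classes in `utils/layers.py` only appear above in the context of TGI's weight-loading machinery (`FastLinear`, `Weights`, safetensors routing). The following is a minimal, self-contained sketch of the same pattern for readers of the patch; plain `nn.Linear` stands in for `FastLinear`/`TensorParallelHead`, and the sizes and head counts are illustrative values, not anything taken from the PR.

```python
# Hedged sketch of the speculative-head pattern introduced in utils/layers.py.
# Not the TGI implementation: nn.Linear replaces the sharded/quantized layers,
# and hidden_size / vocab_size / head counts are made-up example values.
from typing import Optional, Tuple

import torch
from torch import nn


class ResBlock(nn.Module):
    """Residual SiLU block, mirroring the ResBlock moved into layers.py."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.act = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.act(self.linear(x))


class MedusaHead(nn.Module):
    """A stack of residual blocks followed by an output projection."""

    def __init__(self, hidden_size: int, vocab_size: int, num_layers: int):
        super().__init__()
        self.blocks = nn.ModuleList(
            [ResBlock(hidden_size) for _ in range(num_layers)]
        )
        self.out = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        return self.out(x)


class SpeculativeHead(nn.Module):
    """Regular lm_head plus optional Medusa heads; Medusa is strictly additive."""

    def __init__(self, lm_head: nn.Module, medusa_heads: Optional[nn.ModuleList]):
        super().__init__()
        self.lm_head = lm_head
        self.medusa = medusa_heads

    def forward(
        self, hidden: torch.Tensor
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        logits = self.lm_head(hidden)
        if self.medusa is None:
            return logits, None
        # One extra set of logits per Medusa head, stacked on a new dim=1 axis.
        speculative_logits = torch.stack(
            [head(hidden) for head in self.medusa], dim=1
        )
        return logits, speculative_logits


if __name__ == "__main__":
    hidden_size, vocab_size = 64, 128  # illustrative sizes
    lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    heads = nn.ModuleList(
        [MedusaHead(hidden_size, vocab_size, num_layers=1) for _ in range(4)]
    )
    head = SpeculativeHead(lm_head, heads)
    hidden = torch.randn(2, hidden_size)  # [batch, hidden]
    logits, spec = head(hidden)
    print(logits.shape, spec.shape)  # torch.Size([2, 128]) torch.Size([2, 4, 128])
```

When `use_medusa` is unset the head degenerates to the plain `lm_head`, which is why every model class in the patch can take the new keyword without changing behaviour for non-Medusa checkpoints.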
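
The other half of the patch is the calling convention: `forward()` now returns `(logits, speculative_logits)` everywhere, generic `transformers` models fall back to `None` (the `isinstance(outputs, tuple)` check in `seq2seq_lm.py`), and CUDA-graph replay slices both padded output buffers down to the live batch size (the Mamba hunk). A small sketch of that convention follows; `cuda_graph` and `bs` are illustrative names for the cached-graph dict and live batch size, and `unpack_forward` is a hypothetical helper, not a function from the PR.

```python
# Hedged sketch of the (logits, speculative_logits) calling convention,
# assuming a dict-shaped graph cache like the graph_dict stored by Mamba.
from typing import Any, Optional, Tuple

import torch


def unpack_forward(outputs: Any) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Custom TGI models return a tuple; generic transformers models return a
    # ModelOutput, so speculative logits default to None in that case.
    if isinstance(outputs, tuple):
        logits, speculative_logits = outputs
    else:
        logits, speculative_logits = outputs.logits, None
    return logits, speculative_logits


def slice_graph_outputs(
    cuda_graph: dict, bs: int
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Graph buffers are captured at the maximum batch size; slice both the
    # regular and the speculative logits back down to the current batch.
    speculative_logits = (
        cuda_graph["speculative_logits"][:bs]
        if cuda_graph.get("speculative_logits") is not None
        else None
    )
    return cuda_graph["logits"][:bs], speculative_logits
```

Callers that do not speculate simply ignore the second element, which keeps the token-generation loop identical for models without a Medusa head.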