diff --git a/server/text_generation_server/layers/__init__.py b/server/text_generation_server/layers/__init__.py
index c3a6c921..c29dd092 100644
--- a/server/text_generation_server/layers/__init__.py
+++ b/server/text_generation_server/layers/__init__.py
@@ -3,11 +3,11 @@ from text_generation_server.layers.tensor_parallel import (
     TensorParallelRowLinear,
     TensorParallelEmbedding,
 )
-from text_generation_server.layers.speculative import SpeculativeHead
 from text_generation_server.layers.linear import (
     get_linear,
     FastLinear,
 )
+from text_generation_server.layers.speculative import SpeculativeHead

 # Just to add the `load` methods.
 from text_generation_server.layers.layernorm import load_layer_norm
diff --git a/server/text_generation_server/layers/medusa.py b/server/text_generation_server/layers/medusa.py
index 4ac86978..2e9a010f 100644
--- a/server/text_generation_server/layers/medusa.py
+++ b/server/text_generation_server/layers/medusa.py
@@ -69,21 +69,24 @@ class MedusaHeadV1(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        path = speculator["path"]
+        medusa_config = str(Path(path) / "config.json")

-        with open(medusa_config, "r") as f:
-            medusa_config = json.load(f)
-        routing = weights.routing
-        with safe_open(filename, framework="pytorch") as f:
-            for k in f.keys():
-                if k in routing and routing[k] != filename:
-                    raise RuntimeError(
-                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                    )
-                routing[k] = filename
+        for fname in speculator["model_paths"]:
+            filename = str(Path(path) / fname)
+
+            with open(medusa_config, "r") as f:
+                medusa_config = json.load(f)
+            routing = weights.routing
+            with safe_open(filename, framework="pytorch") as f:
+                for k in f.keys():
+                    if k in routing and routing[k] != filename:
+                        raise RuntimeError(
+                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                        )
+                    routing[k] = filename

         medusa = MedusaModel(config, medusa_config, weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
@@ -108,10 +111,10 @@ class MedusaHeadV2(nn.Module):
         from safetensors import safe_open
         import json

-        use_medusa = config.use_medusa
+        speculator = config.speculator

-        medusa_config = str(Path(use_medusa) / "config.json")
-        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+        medusa_config = str(Path(speculator) / "config.json")
+        filename = str(Path(speculator) / "medusa_lm_head.safetensors")

         with open(medusa_config, "r") as f:
             medusa_config = json.load(f)
diff --git a/server/text_generation_server/layers/mlp.py b/server/text_generation_server/layers/mlp.py
new file mode 100644
index 00000000..f08cb673
--- /dev/null
+++ b/server/text_generation_server/layers/mlp.py
@@ -0,0 +1,176 @@
+import torch
+import math
+from torch import nn
+from torch.nn import functional as F
+from typing import Optional, Tuple
+from text_generation_server.layers import TensorParallelEmbedding, FastLinear
+from text_generation_server.layers.tensor_parallel import TensorParallelHead
+from text_generation_server.utils.speculate import get_speculate
+
+
+class MLPSpeculatorLayerNorm(nn.Module):
+    """
+    A L2 normalization implementation
+    ...
+    Args
+    ----
+    normalized_shape : int
+        Dimensionality of input data (size of final tensor axis)
+    elementwise_scale_weight : torch.Tensor
+        learned scaling term after normalization?
+ elementwise_shift_bias : torch.Tensor + learned bias term after normalization? + eps : float + Safety term to prevent division by zero. Make sure the chosen value fits in the range of your encoding scheme (i.e. fp16 requires eps >= 6e-8). + """ + + def __init__( + self, + prefix, + config, + weights, + eps=1e-06, + ): + super(MLPSpeculatorLayerNorm, self).__init__() + self.weight = weights.get_tensor(f"{prefix}.weight") + self.bias = weights.get_tensor(f"{prefix}.bias") + self.eps = eps + + def forward(self, x): + xf = x + xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps) + x = xf.type_as(x) + x = self.weight * x + x = x + self.bias + return x + + +class MLPSpeculatorModel(torch.nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.config = config + self.n_predict = get_speculate() + self.hidden_size = config.hidden_size + self.emb = nn.ModuleList( + [ + TensorParallelEmbedding(f"{prefix}.emb.{i}", weights) + for i in range(self.n_predict) + ] + ) + self.proj = [ + FastLinear.load( + config, + prefix=f"{prefix}.proj.{i}", + weights=weights, + bias=False, + ) + for i in range(self.n_predict) + ] + self.head = nn.ModuleList( + [ + FastLinear.load(config, f"{prefix}.head.{i}", weights, bias=False) + for i in range(self.n_predict) + ] + ) + self.ln = nn.ModuleList( + [ + MLPSpeculatorLayerNorm( + prefix=f"{prefix}.ln.{i}", + config=config, + weights=weights, + ) + for i in range(self.n_predict) + ] + ) + + # Weights ensure that state_0 accounts for 50% of state magnitude by final head in expectation + self.state_weight = 0.5 ** (0.5 / self.n_predict) + self.emb_weight = math.sqrt(1 - self.state_weight**2) + self.activation = nn.GELU() + # TODO + self.vsize = config.vocab_size + self.inner_dim = config.speculator_config["inner_dim"] + self.top_k_tokens_per_head = [1] * self.n_predict + + def forward( + self, + hidden_states: torch.Tensor, + input_ids: torch.Tensor, + ): + top_k_tokens_per_head = self.top_k_tokens_per_head + + # k indicates # of candidates + # h indicates # of generated tokens + state = hidden_states + b = state.size(0) + ind = input_ids.unsqueeze(0) + all_probs = torch.empty( + b, self.n_predict, self.vsize, device=state.device + ) # b k h v + assert ( + len(top_k_tokens_per_head) == self.n_predict + ), f"You must provide a topk number for each head ({self.n_predict} heads, {len(top_k_tokens_per_head)} provided)" + for i in range(self.n_predict): + # Project and predict + z = self.emb[i](ind) + z = z.mul(self.emb_weight * math.sqrt(self.inner_dim / 2)) # b k d + state = self.proj[i](state) * self.state_weight + z + state = self.activation(self.ln[i](state)) # b k d + probs = F.log_softmax(self.head[i](state), dim=-1) # b k v + _probs, preds = probs.topk(top_k_tokens_per_head[i], dim=-1) # b k k' + + # Update candidate set with new predictions + + # Update distribution set with new logits + all_probs[:, i] = probs.exp() + + # Update state, log_probs and ind for new predictions + state = state.unsqueeze(2).expand( + -1, -1, top_k_tokens_per_head[i], -1 + ) # b k k' d + state = state.reshape(-1, b, state.size(3)) # b kk' d + ind = preds.view(-1, b) # b kk' + + speculative_logits = all_probs + return speculative_logits + + +class MLPSpeculatorHead(nn.Module): + def __init__(self, lm_head, mlp_speculator): + super().__init__() + self.lm_head = lm_head + self.mlp_speculator = mlp_speculator + + def forward( + self, input: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + logits = self.lm_head(input) + # If we have 
too many tokens, we skip speculative logits + if input.shape[0] > 128: + return logits, None + + input_ids = logits.argmax(dim=-1) + speculative_logits = self.mlp_speculator(input, input_ids) + return logits, speculative_logits + + @staticmethod + def load(config, prefix: str, weights): + from pathlib import Path + from safetensors import safe_open + + speculator_path = config.speculator["path"] + + for fname in config.speculator["model_paths"]: + filename = str(Path(speculator_path) / fname) + routing = weights.routing + with safe_open(filename, framework="pytorch") as f: + for k in f.keys(): + if k in routing and routing[k] != filename: + raise RuntimeError( + f"Key {k} was found in multiple files: {filename} and {routing[k]}" + ) + routing[k] = filename + + mlp_speculator = MLPSpeculatorModel(config, "speculator", weights) + lm_head = TensorParallelHead.load(config, prefix, weights) + return MLPSpeculatorHead(lm_head, mlp_speculator) diff --git a/server/text_generation_server/layers/speculative.py b/server/text_generation_server/layers/speculative.py index 663f8c2e..4b977a56 100644 --- a/server/text_generation_server/layers/speculative.py +++ b/server/text_generation_server/layers/speculative.py @@ -1,34 +1,51 @@ import torch +import json from typing import Tuple, Optional -from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 from text_generation_server.layers.tensor_parallel import TensorParallelHead +from text_generation_server.layers.medusa import MedusaHeadV1, MedusaHeadV2 +from text_generation_server.layers.mlp import MLPSpeculatorHead class SpeculativeHead(torch.nn.Module): - def __init__(self, lm_head, medusa): + def __init__(self, lm_head, speculator): super().__init__() self.head = lm_head - self.medusa = medusa + self.speculator = speculator @staticmethod def load(config, prefix: str, weights): - use_medusa = config.use_medusa - if use_medusa: - lm_head = None + speculator = config.speculator + if speculator: + speculator_path = config.speculator["path"] + speculator_config = str(speculator_path / "config.json") + + with open(speculator_config, "r") as f: + speculator_config = json.load(f) + + config.speculator_config = speculator_config try: - medusa = MedusaHeadV1.load(config, prefix, weights) - except: - medusa = MedusaHeadV2(config, prefix, weights) + architecture = speculator_config["architectures"][0] + + if architecture == "MLPSpeculatorPreTrainedModel": + speculator = MLPSpeculatorHead.load(config, prefix, weights) + else: + speculator = None + except KeyError: + try: + speculator = MedusaHeadV1.load(config, prefix, weights) + except: + speculator = MedusaHeadV2(config, prefix, weights) + lm_head = None else: lm_head = TensorParallelHead.load(config, prefix, weights) - medusa = None - return SpeculativeHead(lm_head, medusa) + speculator = None + return SpeculativeHead(lm_head, speculator) def forward( self, input: torch.Tensor ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - if self.medusa is not None: - return self.medusa(input) + if self.speculator is not None: + return self.speculator(input) assert self.head is not None logits = self.head(input) diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index b52765d7..e9761dfe 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -1,9 +1,10 @@ import torch +import os from loguru import logger from transformers.configuration_utils import PretrainedConfig from 
transformers.models.auto import modeling_auto -from huggingface_hub import hf_hub_download +from huggingface_hub import hf_hub_download, HfApi from typing import Optional from pathlib import Path @@ -135,8 +136,9 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + model_type = config_dict.get("model_type", None) - use_medusa = None + speculator = None if "medusa_num_heads" in config_dict: medusa_model_id = model_id medusa_revision = revision @@ -156,6 +158,8 @@ def get_model( config_dict, _ = PretrainedConfig.get_config_dict( model_id, revision=revision, trust_remote_code=trust_remote_code ) + # Reload model type from parent. + model_type = config_dict.get("model_type", None) is_local = Path(medusa_model_id).exists() if not is_local: medusa_config = hf_hub_download( @@ -166,11 +170,70 @@ def get_model( revision=medusa_revision, filename="medusa_lm_head.safetensors", ) - use_medusa = Path(medusa_config).parent + speculator = { + "path": Path(medusa_config).parent, + "model_paths": ["medusa_lm_head.safetensors"], + } else: - use_medusa = Path(medusa_model_id) + speculator = { + "path": Path(medusa_model_id), + "model_paths": ["medusa_lm_head.safetensors"], + } method = "medusa" + elif model_type == "mlp_speculator": + mlp_model_id = model_id + mlp_revision = revision + model_id = config_dict["base_model_name_or_path"] + revision = "main" + speculate_mlp = config_dict["n_predict"] + if speculate is not None: + if speculate > speculate_mlp: + raise RuntimeError( + f"Speculate is set to `{speculate}` but this mlp_speculator models only has `{speculate_mlp}` heads, please make them match" + ) + else: + set_speculate(speculate) + else: + set_speculate(speculate_mlp) + + config_dict, _ = PretrainedConfig.get_config_dict( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + # Reload model type from parent. 
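# Illustrative aside, not part of the patch (values below are made up): the medusa
# branch above and the mlp_speculator branch below both reduce to the same small
# payload that get_model() forwards as `speculator`, e.g.
#
#     speculator = {
#         "path": Path("/data/my-speculator"),    # dir holding config.json + weights
#         "model_paths": ["model.safetensors"],   # safetensors files to register
#     }
#
# MedusaHeadV1.load and MLPSpeculatorHead.load later read config.json from "path"
# and route every file listed in "model_paths" through weights.routing.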
+ model_type = config_dict.get("model_type", None) + is_local = Path(mlp_model_id).exists() + extension = ".safetensors" + if not is_local: + mlp_speculator_config = hf_hub_download( + mlp_model_id, revision=mlp_revision, filename="config.json" + ) + api = HfApi() + info = api.model_info(mlp_model_id, revision=mlp_revision) + filenames = [ + s.rfilename + for s in info.siblings + if s.rfilename.endswith(extension) + and len(s.rfilename.split("/")) == 1 + and "arguments" not in s.rfilename + and "args" not in s.rfilename + and "training" not in s.rfilename + ] + for filename in filenames: + hf_hub_download( + mlp_model_id, + revision=mlp_revision, + filename=filename, + ) + speculator = { + "path": Path(mlp_speculator_config).parent, + "model_paths": filenames, + } + else: + speculator = Path(mlp_model_id) + filenames = [p for p in os.listdir(speculator) if p.endswith(extension)] + speculator = {"path": speculator, "model_paths": filenames} + method = "mlp_speculator" else: method = "n-gram" @@ -178,7 +241,6 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict.get("model_type", None) if model_type is None: # TODO: fix how we determine model type for Mamba if "ssm_cfg" in config_dict: @@ -202,7 +264,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -212,7 +274,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -227,7 +289,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -240,7 +302,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -250,7 +312,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -259,7 +321,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -270,7 +332,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -279,7 +341,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -288,7 +350,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -299,7 +361,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -308,7 +370,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -323,7 +385,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -334,7 +396,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, 
trust_remote_code=trust_remote_code, ) @@ -345,7 +407,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -355,7 +417,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -366,7 +428,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -377,7 +439,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -388,7 +450,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -399,7 +461,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -410,7 +472,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -424,7 +486,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -435,7 +497,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -444,7 +506,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -458,7 +520,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -469,7 +531,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -483,7 +545,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -494,7 +556,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -520,7 +582,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -544,7 +606,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -554,7 +616,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -564,7 +626,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -574,7 +636,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -586,7 +648,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, 
trust_remote_code=trust_remote_code, ) @@ -599,7 +661,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -623,7 +685,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -632,7 +694,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -644,7 +706,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) @@ -653,7 +715,7 @@ def get_model( model_id, revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/bloom.py b/server/text_generation_server/models/bloom.py index 67129ec3..1e3dd10c 100644 --- a/server/text_generation_server/models/bloom.py +++ b/server/text_generation_server/models/bloom.py @@ -42,7 +42,7 @@ class BLOOMSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -71,7 +71,7 @@ class BLOOMSharded(CausalLM): ) config.pad_token_id = 3 config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 93ec6ba4..81a02163 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -482,12 +482,12 @@ class CausalLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - if use_medusa: - raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") if torch.cuda.is_available(): device = torch.device("cuda") diff --git a/server/text_generation_server/models/custom_modeling/idefics2.py b/server/text_generation_server/models/custom_modeling/idefics2.py index 935f049b..51fd7c02 100644 --- a/server/text_generation_server/models/custom_modeling/idefics2.py +++ b/server/text_generation_server/models/custom_modeling/idefics2.py @@ -683,9 +683,9 @@ class Idefics2ForConditionalGeneration(nn.Module): def __init__(self, prefix, config, weights): super().__init__() config.vision_config.quantize = config.quantize - config.vision_config.use_medusa = config.use_medusa + config.vision_config.speculator = config.speculator config.text_config.quantize = config.quantize - config.text_config.use_medusa = config.use_medusa + config.text_config.speculator = config.speculator vision_config = config.vision_config self.text_model = load_text_model( diff --git a/server/text_generation_server/models/custom_modeling/llava_next.py b/server/text_generation_server/models/custom_modeling/llava_next.py index a049f756..de9673aa 100644 --- 
a/server/text_generation_server/models/custom_modeling/llava_next.py +++ b/server/text_generation_server/models/custom_modeling/llava_next.py @@ -135,7 +135,7 @@ class LlavaNextForConditionalGeneration(nn.Module): self.vocab_size = config.text_config.vocab_size self.config = config config.text_config.quantize = config.quantize - config.text_config.use_medusa = config.use_medusa + config.text_config.speculator = config.speculator self.language_model = load_text_model( prefix="language_model" if not prefix else f"{prefix}.language_model", config=config.text_config, diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index f567bea9..5aa7a568 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -1101,6 +1101,8 @@ class FlashCausalLM(Model): next_token_texts = [] left = 0 + logger.info(f"Accepted ids {n_accepted_ids}") + current_stopped = False for j in range(index, index + n_accepted_ids): # Generated token diff --git a/server/text_generation_server/models/flash_cohere.py b/server/text_generation_server/models/flash_cohere.py index f85c7722..b907ee08 100644 --- a/server/text_generation_server/models/flash_cohere.py +++ b/server/text_generation_server/models/flash_cohere.py @@ -24,7 +24,7 @@ class FlashCohere(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -49,7 +49,7 @@ class FlashCohere(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_dbrx.py b/server/text_generation_server/models/flash_dbrx.py index 367d3db0..d5eb1a6e 100644 --- a/server/text_generation_server/models/flash_dbrx.py +++ b/server/text_generation_server/models/flash_dbrx.py @@ -26,7 +26,7 @@ class FlashDbrx(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -74,7 +74,7 @@ class FlashDbrx(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_gemma.py b/server/text_generation_server/models/flash_gemma.py index 7259b820..9c00a056 100644 --- a/server/text_generation_server/models/flash_gemma.py +++ b/server/text_generation_server/models/flash_gemma.py @@ -25,7 +25,7 @@ class FlashGemma(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -50,7 +50,7 @@ class FlashGemma(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_llama.py 
b/server/text_generation_server/models/flash_llama.py index 8ea70713..796fbd47 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -27,7 +27,7 @@ class FlashLlama(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -71,7 +71,7 @@ class FlashLlama(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 48304ad8..b83f49a4 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -313,7 +313,7 @@ class BaseFlashMistral(FlashCausalLM): config_cls=AutoConfig, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, tokenizer_class=AutoTokenizer, @@ -340,7 +340,7 @@ class BaseFlashMistral(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator # Set context windows if getattr(config, "sliding_window", None) is not None: @@ -567,7 +567,7 @@ class FlashMistral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -577,7 +577,7 @@ class FlashMistral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_mixtral.py b/server/text_generation_server/models/flash_mixtral.py index 2ee35e82..587d423f 100644 --- a/server/text_generation_server/models/flash_mixtral.py +++ b/server/text_generation_server/models/flash_mixtral.py @@ -15,7 +15,7 @@ class FlashMixtral(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -25,7 +25,7 @@ class FlashMixtral(BaseFlashMistral): model_id=model_id, revision=revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index 1119bdae..adefaeb2 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -25,7 +25,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,7 +51,7 @@ class FlashNeoXSharded(FlashCausalLM): model_id, revision=revision, 
trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/flash_phi.py b/server/text_generation_server/models/flash_phi.py index cb55f9e6..32b573a9 100644 --- a/server/text_generation_server/models/flash_phi.py +++ b/server/text_generation_server/models/flash_phi.py @@ -25,7 +25,7 @@ class FlashPhi(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -48,7 +48,7 @@ class FlashPhi(FlashCausalLM): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) @@ -58,7 +58,7 @@ class FlashPhi(FlashCausalLM): weights._set_gptq_params(model_id, revision) model = FlashPhiForCausalLM(config, weights) - if use_medusa: + if speculator: from text_generation_server.utils.medusa import MedusaModel from huggingface_hub import hf_hub_download import json @@ -66,19 +66,19 @@ class FlashPhi(FlashCausalLM): from pathlib import Path is_local_model = ( - Path(use_medusa).exists() and Path(use_medusa).is_dir() + Path(speculator).exists() and Path(speculator).is_dir() ) or os.getenv("WEIGHTS_CACHE_OVERRIDE", None) is not None if not is_local_model: medusa_config = hf_hub_download( - use_medusa, revision=revision, filename="config.json" + speculator, revision=revision, filename="config.json" ) medusa_head = hf_hub_download( - use_medusa, revision=revision, filename="medusa_lm_head.pt" + speculator, revision=revision, filename="medusa_lm_head.pt" ) else: - medusa_config = str(Path(use_medusa) / "config.json") - medusa_head = str(Path(use_medusa) / "medusa_lm_head.pt") + medusa_config = str(Path(speculator) / "config.json") + medusa_head = str(Path(speculator) / "medusa_lm_head.pt") with open(medusa_config, "r") as f: config = json.load(f) diff --git a/server/text_generation_server/models/flash_qwen2.py b/server/text_generation_server/models/flash_qwen2.py index cb3cf6b0..59064b30 100644 --- a/server/text_generation_server/models/flash_qwen2.py +++ b/server/text_generation_server/models/flash_qwen2.py @@ -30,7 +30,7 @@ class FlashQwen2(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -53,7 +53,7 @@ class FlashQwen2(BaseFlashMistral): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator # Set context windows if config.sliding_window is not None: diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 33298e1a..e6350611 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -26,7 +26,7 @@ class FlashRWSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, 
trust_remote_code: bool = False, ): @@ -66,7 +66,7 @@ class FlashRWSharded(FlashCausalLM): ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator if config.quantize == "gptq": weights._set_gptq_params(model_id, revision) diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index 66698a3a..2ad36b93 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -29,7 +29,7 @@ class FlashSantacoderSharded(FlashCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -57,7 +57,7 @@ class FlashSantacoderSharded(FlashCausalLM): trust_remote_code=True, ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator config.transpose = config.architectures[0].startswith("GPT2") torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/flash_starcoder2.py b/server/text_generation_server/models/flash_starcoder2.py index 68e726d8..dc5d49be 100644 --- a/server/text_generation_server/models/flash_starcoder2.py +++ b/server/text_generation_server/models/flash_starcoder2.py @@ -29,7 +29,7 @@ class FlashStarcoder2(BaseFlashMistral): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -52,7 +52,7 @@ class FlashStarcoder2(BaseFlashMistral): model_id, revision=revision, trust_remote_code=trust_remote_code ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator # Set context windows if config.sliding_window is not None: diff --git a/server/text_generation_server/models/galactica.py b/server/text_generation_server/models/galactica.py index a46f86be..4656fd45 100644 --- a/server/text_generation_server/models/galactica.py +++ b/server/text_generation_server/models/galactica.py @@ -167,7 +167,7 @@ class GalacticaSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -195,7 +195,7 @@ class GalacticaSharded(CausalLM): ) config.quantize = quantize tokenizer.pad_token_id = config.pad_token_id - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/gpt_neox.py b/server/text_generation_server/models/gpt_neox.py index 1c4cfe7d..c0e1adf2 100644 --- a/server/text_generation_server/models/gpt_neox.py +++ b/server/text_generation_server/models/gpt_neox.py @@ -24,7 +24,7 @@ class GPTNeoxSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -51,7 +51,7 @@ class GPTNeoxSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator 
torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") diff --git a/server/text_generation_server/models/idefics.py b/server/text_generation_server/models/idefics.py index 30bf4aa6..c1fe03e4 100644 --- a/server/text_generation_server/models/idefics.py +++ b/server/text_generation_server/models/idefics.py @@ -31,7 +31,7 @@ class IDEFICSSharded(IdeficsCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -52,7 +52,7 @@ class IDEFICSSharded(IdeficsCausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator config.vision_config.quantize = quantize tokenizer = LlamaTokenizerFast.from_pretrained( diff --git a/server/text_generation_server/models/idefics2.py b/server/text_generation_server/models/idefics2.py index e831af89..314c0500 100644 --- a/server/text_generation_server/models/idefics2.py +++ b/server/text_generation_server/models/idefics2.py @@ -18,7 +18,7 @@ class Idefics2(VlmCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -35,7 +35,7 @@ class Idefics2(VlmCausalLM): model_id=model_id, revision=revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/llava_next.py b/server/text_generation_server/models/llava_next.py index 3983bc85..effe8b91 100644 --- a/server/text_generation_server/models/llava_next.py +++ b/server/text_generation_server/models/llava_next.py @@ -18,7 +18,7 @@ class LlavaNext(VlmCausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -30,7 +30,7 @@ class LlavaNext(VlmCausalLM): model_id=model_id, revision=revision, quantize=quantize, - use_medusa=use_medusa, + speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, ) diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 0884317e..b28b744f 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -408,7 +408,7 @@ class Mamba(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -445,7 +445,7 @@ class Mamba(Model): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) diff --git a/server/text_generation_server/models/mpt.py b/server/text_generation_server/models/mpt.py index 6b3f29a6..8d8b4909 100644 --- a/server/text_generation_server/models/mpt.py +++ b/server/text_generation_server/models/mpt.py @@ -43,7 
+43,7 @@ class MPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -76,7 +76,7 @@ class MPTSharded(CausalLM): config = json.load(f) config = PretrainedConfig(**config) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/opt.py b/server/text_generation_server/models/opt.py index 703e5b58..5b84f4ff 100644 --- a/server/text_generation_server/models/opt.py +++ b/server/text_generation_server/models/opt.py @@ -22,7 +22,7 @@ class OPTSharded(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -48,7 +48,7 @@ class OPTSharded(CausalLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator tokenizer.pad_token_id = config.pad_token_id torch.distributed.barrier(group=self.process_group) diff --git a/server/text_generation_server/models/phi.py b/server/text_generation_server/models/phi.py index cc4e2505..d68866c1 100644 --- a/server/text_generation_server/models/phi.py +++ b/server/text_generation_server/models/phi.py @@ -22,7 +22,7 @@ class Phi(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -53,7 +53,7 @@ class Phi(CausalLM): tokenizer.pad_token = tokenizer.eos_token config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator torch.distributed.barrier(group=self.process_group) filenames = weight_files(model_id, revision=revision, extension=".safetensors") weights = Weights(filenames, device, dtype, process_group=self.process_group) diff --git a/server/text_generation_server/models/rw.py b/server/text_generation_server/models/rw.py index 92c93542..d4764ded 100644 --- a/server/text_generation_server/models/rw.py +++ b/server/text_generation_server/models/rw.py @@ -12,11 +12,11 @@ class RW(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - if use_medusa: + if speculator: raise RuntimeError("Medusa decoding is not enabled for AutoModel") if torch.cuda.is_available(): diff --git a/server/text_generation_server/models/santacoder.py b/server/text_generation_server/models/santacoder.py index 73c21cce..323e4324 100644 --- a/server/text_generation_server/models/santacoder.py +++ b/server/text_generation_server/models/santacoder.py @@ -19,7 +19,7 @@ class SantaCoder(CausalLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): diff --git a/server/text_generation_server/models/seq2seq_lm.py b/server/text_generation_server/models/seq2seq_lm.py index e55a661c..6a0c812f 100644 --- a/server/text_generation_server/models/seq2seq_lm.py +++ 
b/server/text_generation_server/models/seq2seq_lm.py @@ -532,12 +532,12 @@ class Seq2SeqLM(Model): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): - if use_medusa: - raise RuntimeError("Medusa decoding is not enabled for AutoModel") + if speculator: + raise RuntimeError("Speculator decoding is not enabled for AutoModel") if torch.cuda.is_available(): device = torch.device("cuda") diff --git a/server/text_generation_server/models/t5.py b/server/text_generation_server/models/t5.py index 3f3cb965..8e0735e5 100644 --- a/server/text_generation_server/models/t5.py +++ b/server/text_generation_server/models/t5.py @@ -25,7 +25,7 @@ class T5Sharded(Seq2SeqLM): model_id: str, revision: Optional[str] = None, quantize: Optional[str] = None, - use_medusa: Optional[str] = None, + speculator: Optional[str] = None, dtype: Optional[torch.dtype] = None, trust_remote_code: bool = False, ): @@ -43,7 +43,7 @@ class T5Sharded(Seq2SeqLM): trust_remote_code=trust_remote_code, ) config.quantize = quantize - config.use_medusa = use_medusa + config.speculator = speculator tokenizer = AutoTokenizer.from_pretrained( model_id,
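# A condensed usage sketch; `prefix`, `weights`, and the speculator directory are
# illustrative, while the call flow mirrors the code in this diff. Model classes now
# only set `config.speculator` (the {"path", "model_paths"} dict built in
# models/__init__.py), and SpeculativeHead.load() picks the head from the
# speculator's config.json:
#
#     head = SpeculativeHead.load(config, prefix="lm_head", weights=weights)
#     # architectures[0] == "MLPSpeculatorPreTrainedModel" -> MLPSpeculatorHead
#     # no "architectures" key (medusa configs)            -> MedusaHeadV1 / MedusaHeadV2
#     # config.speculator is None                          -> plain TensorParallelHead
#
#     logits, speculative_logits = head(hidden_states)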