diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 00c8b9a8..06792b0d 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -145,7 +145,7 @@ def get_model(
         if speculate is not None:
             if speculate > speculate_medusa:
                 raise RuntimeError(
-                    "Speculate is set to `{speculate}` but this medusa models only has `{speculate_medusa}` heads, please make them match"
+                    f"Speculate is set to `{speculate}` but this medusa model only has `{speculate_medusa}` heads, please make them match"
                 )
             else:
                 set_speculate(speculate)
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 2a9d3914..2c440083 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -814,7 +814,7 @@ class FlashCausalLM(Model):
                 for bs in CUDA_GRAPHS:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
-            except Exception:
+            except torch.cuda.OutOfMemoryError:
                 logger.exception(f"Decode cuda graph warmup failed")
 
         return int(num_blocks * BLOCK_SIZE)
@@ -874,22 +874,14 @@ class FlashCausalLM(Model):
             lm_head_indices = batch.prefill_head_indices
 
         bs = input_ids.shape[0]
-        padded_bs = bs
-        if bs == 3:
-            padded_bs = 4
-        elif 3 < bs <= 8:
-            padded_bs = 8
-        elif bs > 8:
-            padded_bs = (bs + 7) // 8 * 8
+        sorted_padded_bs = sorted([k for k in self.cuda_graphs.keys() if k >= bs])
+        if sorted_padded_bs:
+            # Get associated cuda graph
+            cuda_graph = self.cuda_graphs[sorted_padded_bs[0]]
+        else:
+            cuda_graph = None
 
-        # Try to find an associated cuda graph
-        cuda_graph = self.cuda_graphs.get(padded_bs, None)
-
-        if (
-            cu_seqlen_prefill is not None
-            or cuda_graph is None
-            or batch.speculative_ids is not None
-        ):
+        if cu_seqlen_prefill is not None or cuda_graph is None:
             return self.model.forward(
                 input_ids=input_ids,
                 position_ids=position_ids,
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index 2b95bc74..2c12984b 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -432,12 +432,12 @@ class ResBlock(torch.nn.Module):
 
 
 class MedusaModel(torch.nn.Module):
-    def __init__(self, config, weights):
+    def __init__(self, config, medusa_config, weights):
         super().__init__()
         self.heads = torch.nn.ModuleList(
             [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config["medusa_num_heads"])
+                MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
+                for i in range(medusa_config["medusa_num_heads"])
             ]
         )
 
@@ -447,12 +447,12 @@ class MedusaModel(torch.nn.Module):
 
 
 class MedusaHead(torch.nn.Module):
-    def __init__(self, config, prefix, weights):
+    def __init__(self, config, medusa_config, prefix, weights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
             [
                 ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config["medusa_num_layers"])
+                for i in range(medusa_config["medusa_num_layers"])
             ]
         )
         n = len(self.blocks)
@@ -467,7 +467,7 @@ class MedusaHead(torch.nn.Module):
         return x
 
 
-class SpeculativeHead(nn.Module):
+class MedusaHeadV1(nn.Module):
     def __init__(self, lm_head, medusa):
         super().__init__()
         self.lm_head = lm_head
@@ -475,38 +475,147 @@ class SpeculativeHead(nn.Module):
 
     @staticmethod
     def load(config, prefix: str, weights):
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        medusa = MedusaModel(config, medusa_config, weights)
         lm_head = TensorParallelHead.load(config, prefix, weights)
+        return MedusaHeadV1(lm_head, medusa)
+
+    def forward(
+        self, input: torch.Tensor
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        logits = self.lm_head(input)
+        speculative_logits = self.medusa(input)
+        return logits, speculative_logits
+
+
+class MedusaHeadV2(nn.Module):
+    def __init__(self, config, prefix, weights):
+        super().__init__()
+        from pathlib import Path
+        from safetensors import safe_open
+        import json
+
+        use_medusa = config.use_medusa
+
+        medusa_config = str(Path(use_medusa) / "config.json")
+        filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
+
+        with open(medusa_config, "r") as f:
+            medusa_config = json.load(f)
+        routing = weights.routing
+        with safe_open(filename, framework="pytorch") as f:
+            for k in f.keys():
+                if k in routing and routing[k] != filename:
+                    raise RuntimeError(
+                        f"Key {k} was found in multiple files: {filename} and {routing[k]}"
+                    )
+                routing[k] = filename
+
+        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+
+        assert medusa_config["medusa_num_layers"] == 1
+        self.linear = TensorParallelColumnLinear.load_multi(
+            config,
+            prefixes=[f"{i}.0.linear" for i in range(self.n_medusa_heads)],
+            dim=0,
+            weights=weights,
+            bias=True,
+        )
+        self.process_group = weights.process_group
+        self.world_size = self.process_group.size()
+        self.rank = self.process_group.rank()
+
+        self.act = torch.nn.SiLU()
+
+        self.lm_head = TensorParallelHead.load(config, prefix, weights)
+
+    def forward(self, x):
+        size = x.shape[-1]
+        block_size = (size + self.world_size - 1) // self.world_size
+        start = self.rank * block_size
+        stop = (self.rank + 1) * block_size
+
+        x_block = x[:, start:stop]
+
+        # Compute all medusa heads at the same time, then reshape and move the n_medusa_heads dim to dim 1
+        medusa_res = self.act(self.linear(x)).reshape(
+            *x_block.shape[:-1], self.n_medusa_heads, x_block.shape[-1]
+        )
+
+        # Apply all residual medusa heads
+        output = x[:, start:stop].unsqueeze(-2) + medusa_res
+
+        # Gather medusa heads
+        world_output = [
+            torch.empty_like(output) for _ in range(self.process_group.size())
+        ]
+        torch.distributed.all_gather(world_output, output, group=self.process_group)
+        world_output = torch.cat(world_output, dim=-1)
+
+        # Stack x and medusa residual x
+        stacked_x = torch.cat([x.unsqueeze(-2), world_output], dim=-2)
+
+        # Compute lm head on x + medusa residual x
+        logits = self.lm_head(stacked_x)
+
+        # Finally, split logits from speculative logits
+        logits, speculative_logits = torch.split(
+            logits, [1, self.n_medusa_heads], dim=-2
+        )
+        # Squeeze added dimension
+        logits = logits.squeeze(-2)
+
+        return logits, speculative_logits
+
+
+class SpeculativeHead(nn.Module):
+    def __init__(self, lm_head, medusa):
+        super().__init__()
+        self.head = lm_head
+        self.medusa = medusa
+
+    @staticmethod
+    def load(config, prefix: str, weights):
         use_medusa = config.use_medusa
         if use_medusa:
-            from pathlib import Path
-            from safetensors import safe_open
-            import json
-
-            medusa_config = str(Path(use_medusa) / "config.json")
-            filename = str(Path(use_medusa) / "medusa_lm_head.safetensors")
-
-            with open(medusa_config, "r") as f:
-                config = json.load(f)
-            routing = weights.routing
-            with safe_open(filename, framework="pytorch") as f:
-                for k in f.keys():
-                    if k in routing:
-                        raise RuntimeError(
-                            f"Key {k} was found in multiple files: {filename} and {routing[k]}"
-                        )
-                    weights.routing[k] = filename
-
-            medusa = MedusaModel(config, weights)
+            lm_head = None
+            try:
+                medusa = MedusaHeadV1.load(config, prefix, weights)
+            except:
+                medusa = MedusaHeadV2(config, prefix, weights)
         else:
+            lm_head = TensorParallelHead.load(config, prefix, weights)
             medusa = None
         return SpeculativeHead(lm_head, medusa)
 
     def forward(
         self, input: torch.Tensor
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        logits = self.lm_head(input)
-        speculative_logits = self.medusa(input) if self.medusa is not None else None
-        return logits, speculative_logits
+        if self.medusa is not None:
+            return self.medusa(input)
+
+        assert self.head is not None
+        logits = self.head(input)
+        return logits, None
 
 
 class TensorParallelHead(SuperLayer):
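A minimal sketch of the new CUDA-graph lookup in `FlashCausalLM.forward`: instead of padding the batch size through hard-coded tiers (4, 8, multiples of 8), the forward pass now picks the smallest captured graph that can hold the live batch and falls back to a regular eager forward when none fits. The `cuda_graphs` dict and `lookup_graph` helper below are hypothetical stand-ins for `self.cuda_graphs` and the inlined logic:

```python
# Hypothetical stand-in for self.cuda_graphs: graphs keyed by the padded
# batch size they were captured with during warmup.
cuda_graphs = {1: "graph_bs1", 2: "graph_bs2", 4: "graph_bs4", 8: "graph_bs8"}

def lookup_graph(bs: int):
    # Smallest captured graph large enough for the live batch;
    # None means no graph fits and the model runs a regular forward.
    sorted_padded_bs = sorted(k for k in cuda_graphs if k >= bs)
    return cuda_graphs[sorted_padded_bs[0]] if sorted_padded_bs else None

assert lookup_graph(3) == "graph_bs4"  # bs=3 still pads up to 4
assert lookup_graph(8) == "graph_bs8"
assert lookup_graph(9) is None         # larger than any captured graph
```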
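And a single-process sketch of the arithmetic in `MedusaHeadV2.forward`, assuming `world_size == 1` so the block slice covers the whole hidden dimension and the `all_gather` is a no-op; the sizes and module names are made up for illustration:

```python
import torch

batch, hidden, vocab, n_heads = 2, 16, 32, 3

x = torch.randn(batch, hidden)
linear = torch.nn.Linear(hidden, n_heads * hidden)   # fused medusa projections
lm_head = torch.nn.Linear(hidden, vocab, bias=False)
act = torch.nn.SiLU()

# One fused matmul for every medusa head, then expose the head dimension.
medusa_res = act(linear(x)).reshape(batch, n_heads, hidden)

# Each head predicts a residual on top of the hidden state.
output = x.unsqueeze(-2) + medusa_res                     # (batch, n_heads, hidden)

# Stack x with all residual branches and run the lm_head once over them.
stacked_x = torch.cat([x.unsqueeze(-2), output], dim=-2)  # (batch, 1 + n_heads, hidden)
all_logits = lm_head(stacked_x)

# Row 0 holds the regular next-token logits; the rest are speculative.
logits, speculative_logits = torch.split(all_logits, [1, n_heads], dim=-2)
logits = logits.squeeze(-2)

print(logits.shape)              # torch.Size([2, 32])
print(speculative_logits.shape)  # torch.Size([2, 3, 32])
```

The point of the V2 layout is that the `medusa_num_layers == 1` restriction lets every head's projection be fused into a single `TensorParallelColumnLinear.load_multi`, and the lm_head runs once over the stacked hidden states, so regular and speculative logits come out of one matmul.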