diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py index 8c46ea49..6e4a13cd 100644 --- a/server/text_generation_server/utils/layers.py +++ b/server/text_generation_server/utils/layers.py @@ -8,7 +8,8 @@ from typing import List, Tuple, Optional from loguru import logger from functools import lru_cache -# Dummy comment. +from text_generation_server.utils.speculate import get_speculate + HAS_BITS_AND_BYTES = True try: import bitsandbytes as bnb @@ -445,7 +446,7 @@ class MedusaModel(torch.nn.Module): self.heads = torch.nn.ModuleList( [ MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights) - for i in range(medusa_config["medusa_num_heads"]) + for i in range(get_speculate()) ] ) @@ -542,7 +543,7 @@ class MedusaHeadV2(nn.Module): ) routing[k] = filename - self.n_medusa_heads = medusa_config["medusa_num_heads"] + self.n_medusa_heads = get_speculate() assert medusa_config["medusa_num_layers"] == 1 self.linear = TensorParallelColumnLinear.load_multi(