fix: use get_speculate to set the number of Medusa heads

This commit is contained in:
OlivierDehaene 2024-04-12 20:21:52 +02:00
parent c38a7d7ddd
commit 0c4a634640
1 changed file with 4 additions and 2 deletions

View File

@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
from loguru import logger
from functools import lru_cache
from text_generation_server.utils.speculate import get_speculate
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
@@ -437,7 +439,7 @@ class MedusaModel(torch.nn.Module):
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-for i in range(medusa_config["medusa_num_heads"])
+for i in range(get_speculate())
]
)
@@ -534,7 +536,7 @@ class MedusaHeadV2(nn.Module):
)
routing[k] = filename
-self.n_medusa_heads = medusa_config["medusa_num_heads"]
+self.n_medusa_heads = get_speculate()
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(