fix: use get_speculate to the number of layers (#1737)
This commit is contained in:
parent
743ecbca3a
commit
8332fc4908
|
@ -8,7 +8,8 @@ from typing import List, Tuple, Optional
|
|||
from loguru import logger
|
||||
from functools import lru_cache
|
||||
|
||||
# Dummy comment.
|
||||
from text_generation_server.utils.speculate import get_speculate
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
@ -445,7 +446,7 @@ class MedusaModel(torch.nn.Module):
|
|||
self.heads = torch.nn.ModuleList(
|
||||
[
|
||||
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
|
||||
for i in range(medusa_config["medusa_num_heads"])
|
||||
for i in range(get_speculate())
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -542,7 +543,7 @@ class MedusaHeadV2(nn.Module):
|
|||
)
|
||||
routing[k] = filename
|
||||
|
||||
self.n_medusa_heads = medusa_config["medusa_num_heads"]
|
||||
self.n_medusa_heads = get_speculate()
|
||||
|
||||
assert medusa_config["medusa_num_layers"] == 1
|
||||
self.linear = TensorParallelColumnLinear.load_multi(
|
||||
|
|
Loading…
Reference in New Issue