fix: use get_speculate to the number of layers (#1737)
This commit is contained in:
parent 743ecbca3a
commit 8332fc4908
@@ -8,7 +8,8 @@ from typing import List, Tuple, Optional
 from loguru import logger
 from functools import lru_cache
 
-# Dummy comment.
+from text_generation_server.utils.speculate import get_speculate
+
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb

@@ -445,7 +446,7 @@ class MedusaModel(torch.nn.Module):
         self.heads = torch.nn.ModuleList(
             [
                 MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-                for i in range(medusa_config["medusa_num_heads"])
+                for i in range(get_speculate())
             ]
         )
 
@@ -542,7 +543,7 @@ class MedusaHeadV2(nn.Module):
                         )
                     routing[k] = filename
 
-        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+        self.n_medusa_heads = get_speculate()
 
         assert medusa_config["medusa_num_layers"] == 1
         self.linear = TensorParallelColumnLinear.load_multi(
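For context, both hunks replace a Medusa head count read from the checkpoint's medusa_config with the server's runtime speculation setting. The sketch below shows the kind of module-level getter/setter that text_generation_server.utils.speculate is assumed to provide; its exact contents are an assumption for illustration, not copied from the repository.

# Hypothetical sketch of text_generation_server/utils/speculate.py
# (an assumed shape, not the repository's actual file).

SPECULATE: int = 0


def set_speculate(speculate: int):
    # Assumed to be called once at server startup with the speculation depth.
    global SPECULATE
    SPECULATE = speculate


def get_speculate() -> int:
    # Read wherever the number of speculative (Medusa) heads is needed,
    # e.g. when MedusaModel sizes its ModuleList in the hunk above.
    return SPECULATE

With the heads sized from get_speculate() instead of medusa_config["medusa_num_heads"], the number of Medusa heads actually instantiated presumably follows the value the server was launched with, even if the checkpoint was trained with a different number.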