This commit is contained in:
OlivierDehaene 2024-04-15 17:55:45 -05:00 committed by GitHub
commit 4a7fcef495
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 4 additions and 2 deletions

View File

@@ -8,6 +8,8 @@ from typing import List, Tuple, Optional
from loguru import logger
from functools import lru_cache
+from text_generation_server.utils.speculate import get_speculate
HAS_BITS_AND_BYTES = True
try:
import bitsandbytes as bnb
@@ -437,7 +439,7 @@ class MedusaModel(torch.nn.Module):
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-for i in range(medusa_config["medusa_num_heads"])
+for i in range(get_speculate())
]
)
@@ -534,7 +536,7 @@ class MedusaHeadV2(nn.Module):
)
routing[k] = filename
-self.n_medusa_heads = medusa_config["medusa_num_heads"]
+self.n_medusa_heads = get_speculate()
assert medusa_config["medusa_num_layers"] == 1
self.linear = TensorParallelColumnLinear.load_multi(