fix: use get_speculate for the number of layers (#1737)

OlivierDehaene 2024-04-30 11:45:26 +02:00 committed by GitHub
parent 743ecbca3a
commit 8332fc4908
1 changed file with 4 additions and 3 deletions

@@ -8,7 +8,8 @@ from typing import List, Tuple, Optional
 from loguru import logger
 from functools import lru_cache
-# Dummy comment.
+from text_generation_server.utils.speculate import get_speculate
 HAS_BITS_AND_BYTES = True
 try:
     import bitsandbytes as bnb
@@ -445,7 +446,7 @@ class MedusaModel(torch.nn.Module):
         self.heads = torch.nn.ModuleList(
             [
                 MedusaHead(config, medusa_config, prefix=f"{i}", weights=weights)
-                for i in range(medusa_config["medusa_num_heads"])
+                for i in range(get_speculate())
             ]
         )
@@ -542,7 +543,7 @@ class MedusaHeadV2(nn.Module):
                     )
                 routing[k] = filename
-        self.n_medusa_heads = medusa_config["medusa_num_heads"]
+        self.n_medusa_heads = get_speculate()
         assert medusa_config["medusa_num_layers"] == 1
         self.linear = TensorParallelColumnLinear.load_multi(
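
For reference, the speculate value returned by get_speculate() is typically installed once at model load time and then read wherever the number of speculative (Medusa) heads is needed. The sketch below is only an assumption about how a module like text_generation_server/utils/speculate.py can be structured; get_speculate is the only name taken from this diff, while set_speculate and the default value are illustrative.

# Minimal sketch, assuming a module-level value with a getter/setter
# (only get_speculate appears in the diff; the rest is illustrative).
SPECULATE = 0


def set_speculate(speculate: int) -> None:
    # Assumed to be called once at startup, e.g. from a server-level
    # speculate setting or from the Medusa config's head count.
    global SPECULATE
    SPECULATE = speculate


def get_speculate() -> int:
    # Read back by MedusaModel / MedusaHeadV2 above to decide how many
    # speculative heads to build.
    return SPECULATE

With this indirection, MedusaModel and MedusaHeadV2 build only as many heads as the runtime speculate setting requests, rather than always materializing medusa_config["medusa_num_heads"] heads, which appears to be the point of this change.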