fix: fix logic if sliding window key is not present in config (#1352)
parent 9b56d3fbf5
commit 1b1bfa49b0
@@ -281,9 +281,10 @@ def get_model(
         )
 
     if model_type == "mistral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMistral(
                 model_id,
                 revision,
@@ -293,9 +294,10 @@ def get_model(
         )
 
     if model_type == "mixtral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMixtral(
                 model_id,
                 revision,
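The change in these two hunks: the old guard indexed config_dict["sliding_window"] directly, so a checkpoint whose config.json has no sliding_window entry crashed with a KeyError before any model was selected. The new guard reads the key with a -1 default and treats None and -1 alike as "no sliding window". A minimal sketch of the before/after behavior (the FLASH_ATTENTION and HAS_FLASH_ATTN_V2_CUDA values below are stand-ins; the real flags are probed from the environment at import time):

FLASH_ATTENTION = True
HAS_FLASH_ATTN_V2_CUDA = True  # stand-in flags, not the real probed values

config_dict = {}  # a config.json that never set "sliding_window"

# Old guard: crashes as soon as the key is missing.
try:
    (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
        config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
    )
except KeyError as err:
    print(f"old guard raised KeyError on {err}")

# New guard: an absent key falls back to -1 and counts as "no sliding window".
sliding_window = config_dict.get("sliding_window", -1)
use_flash = (
    (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
) or HAS_FLASH_ATTN_V2_CUDA
print(f"new guard selects the flash model: {use_flash}")  # True

Note that in the rewritten condition, HAS_FLASH_ATTN_V2_CUDA alone is enough to select FlashMistral/FlashMixtral even when a positive window is configured; only the non-windowed case still depends on plain FLASH_ATTENTION.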
@@ -60,7 +60,7 @@ class MistralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -72,7 +72,7 @@ class MixtralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         num_experts_per_tok=2,
         num_local_experts=8,
         **kwargs,
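Dropping the 4096 default in the two config classes is what makes the lookup above meaningful: if the constructor default re-introduced a window, a checkpoint that never specified one would silently be treated as windowed. A sketch of the distinction with a pared-down stand-in class (TinyConfig is hypothetical, not part of the codebase):

class TinyConfig:
    # Hypothetical stand-in for the MistralConfig/MixtralConfig signatures above.
    def __init__(self, sliding_window=None, **kwargs):
        # Old default (4096): an unspecified window came back as a real window.
        # New default (None): "not specified" survives construction as None.
        self.sliding_window = sliding_window

assert TinyConfig().sliding_window is None                     # key absent in config.json
assert TinyConfig(sliding_window=4096).sliding_window == 4096  # explicitly windowed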
@@ -33,7 +33,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if sliding_window != -1 else None
 
         if speculate is None:
             speculate = get_speculate()
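Finally, Model.__init__ folds the -1 sentinel back to None, so everything downstream of the constructor can keep checking self.sliding_window is not None without knowing about the sentinel. A sketch of the normalization in isolation (normalize_sliding_window is a hypothetical helper mirroring the changed line):

def normalize_sliding_window(sliding_window):
    # -1 is the "key was absent" sentinel and collapses back to None.
    return sliding_window if sliding_window != -1 else None

assert normalize_sliding_window(-1) is None
assert normalize_sliding_window(None) is None
assert normalize_sliding_window(4096) == 4096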