fix: fix logic when the sliding window key is not present in the config (#1352)

OlivierDehaene authored 2023-12-15 14:56:17 +01:00, committed by GitHub
parent 9b56d3fbf5
commit 1b1bfa49b0
4 changed files with 11 additions and 9 deletions


@@ -281,9 +281,10 @@ def get_model(
             )
     if model_type == "mistral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMistral(
                 model_id,
                 revision,
@@ -293,9 +294,10 @@ def get_model(
             )
     if model_type == "mixtral":
-        if (config_dict["sliding_window"] is None and FLASH_ATTENTION) or (
-            config_dict["sliding_window"] > 0 and HAS_FLASH_ATTN_V2_CUDA
-        ):
+        sliding_window = config_dict.get("sliding_window", -1)
+        if (
+            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
+        ) or HAS_FLASH_ATTN_V2_CUDA:
             return FlashMixtral(
                 model_id,
                 revision,
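In both branches the old code indexed the config dict directly, so a checkpoint whose config.json omits "sliding_window" entirely raised a KeyError before any model could be selected. A minimal sketch of the failure mode and of the new lookup (the config dict below is hypothetical):

    config_dict = {"model_type": "mistral"}  # config.json without a "sliding_window" key

    # Old logic: direct indexing raises as soon as the key is absent.
    # config_dict["sliding_window"]  # -> KeyError: 'sliding_window'

    # New logic: .get() with a -1 sentinel treats a missing key as
    # "no sliding window configured" instead of crashing.
    sliding_window = config_dict.get("sliding_window", -1)
    assert sliding_window == -1

    # The old second branch was fragile too: with the value set to None and
    # FLASH_ATTENTION false, `config_dict["sliding_window"] > 0` evaluates
    # None > 0, which raises TypeError on Python 3.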


@@ -60,7 +60,7 @@ class MistralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         **kwargs,
     ):
         self.vocab_size = vocab_size
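Flipping the default from 4096 to None means a config that never sets the field now reports "no sliding window" rather than silently assuming a 4096-token window. A toy stand-in (not the real class, which takes many more keyword arguments) illustrating the change:

    class ToyConfig:
        def __init__(self, sliding_window=None, **kwargs):  # default was 4096 before this commit
            self.sliding_window = sliding_window

    assert ToyConfig().sliding_window is None                     # field omitted -> no window
    assert ToyConfig(sliding_window=4096).sliding_window == 4096  # explicit value kept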


@@ -72,7 +72,7 @@ class MixtralConfig(PretrainedConfig):
         pretraining_tp=1,
         tie_word_embeddings=False,
         rope_theta=10000.0,
-        sliding_window=4096,
+        sliding_window=None,
         num_experts_per_tok=2,
         num_local_experts=8,
         **kwargs,


@@ -33,7 +33,7 @@ class Model(ABC):
         self.device = device
         self.rank = rank
         self.world_size = world_size
-        self.sliding_window = sliding_window
+        self.sliding_window = sliding_window if sliding_window != -1 else None
         if speculate is None:
             speculate = get_speculate()
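Finally, the base Model folds the -1 sentinel from get_model back into None, so everything downstream only has to distinguish None (no window) from a positive window size. A standalone sketch of that normalization (the helper name is made up for illustration):

    def normalize_sliding_window(sliding_window):
        # -1 is the "key absent from config.json" sentinel used in get_model();
        # collapse it into None so callers check a single value.
        return sliding_window if sliding_window != -1 else None

    assert normalize_sliding_window(-1) is None
    assert normalize_sliding_window(None) is None
    assert normalize_sliding_window(4096) == 4096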