add 'mamba' as model config

This commit is contained in:
erikkaum 2024-08-01 18:16:32 +02:00
parent 47447ef017
commit 060b2db0df
2 changed files with 3 additions and 2 deletions

View File

@ -148,6 +148,7 @@ pub enum Config {
Idefics,
Idefics2(Idefics2),
Ssm,
Mamba,
GptBigcode,
Santacoder,
Bloom,

View File

@ -210,7 +210,7 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/databricks/dbrx-instruct",
}
MAMBA = {
"type": "ssm",
"type": ["ssm", "mamba"],
"name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
}
@ -526,7 +526,7 @@ def get_model(
dtype=dtype,
trust_remote_code=trust_remote_code,
)
elif model_type == MAMBA:
elif model_type in MAMBA:
return Mamba(
model_id,
revision,