add 'mamba' as model config

This commit is contained in:
erikkaum 2024-08-01 18:16:32 +02:00
parent 47447ef017
commit 060b2db0df
2 changed files with 3 additions and 2 deletions

View File

@ -148,6 +148,7 @@ pub enum Config {
Idefics, Idefics,
Idefics2(Idefics2), Idefics2(Idefics2),
Ssm, Ssm,
Mamba,
GptBigcode, GptBigcode,
Santacoder, Santacoder,
Bloom, Bloom,

View File

@ -210,7 +210,7 @@ class ModelType(enum.Enum):
"url": "https://huggingface.co/databricks/dbrx-instruct", "url": "https://huggingface.co/databricks/dbrx-instruct",
} }
MAMBA = { MAMBA = {
"type": "ssm", "type": ["ssm", "mamba"],
"name": "Mamba", "name": "Mamba",
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
} }
@ -526,7 +526,7 @@ def get_model(
dtype=dtype, dtype=dtype,
trust_remote_code=trust_remote_code, trust_remote_code=trust_remote_code,
) )
elif model_type == MAMBA: elif model_type in MAMBA:
return Mamba( return Mamba(
model_id, model_id,
revision, revision,