add 'mamba' as model config
This commit is contained in:
parent
47447ef017
commit
060b2db0df
|
@ -148,6 +148,7 @@ pub enum Config {
|
||||||
Idefics,
|
Idefics,
|
||||||
Idefics2(Idefics2),
|
Idefics2(Idefics2),
|
||||||
Ssm,
|
Ssm,
|
||||||
|
Mamba,
|
||||||
GptBigcode,
|
GptBigcode,
|
||||||
Santacoder,
|
Santacoder,
|
||||||
Bloom,
|
Bloom,
|
||||||
|
|
|
@ -210,7 +210,7 @@ class ModelType(enum.Enum):
|
||||||
"url": "https://huggingface.co/databricks/dbrx-instruct",
|
"url": "https://huggingface.co/databricks/dbrx-instruct",
|
||||||
}
|
}
|
||||||
MAMBA = {
|
MAMBA = {
|
||||||
"type": "ssm",
|
"type": ["ssm", "mamba"],
|
||||||
"name": "Mamba",
|
"name": "Mamba",
|
||||||
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
|
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
|
||||||
}
|
}
|
||||||
|
@ -526,7 +526,7 @@ def get_model(
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
)
|
)
|
||||||
elif model_type == MAMBA:
|
elif model_type in MAMBA:
|
||||||
return Mamba(
|
return Mamba(
|
||||||
model_id,
|
model_id,
|
||||||
revision,
|
revision,
|
||||||
|
|
Loading…
Reference in New Issue