add 'mamba' as model config
This commit is contained in:
parent
47447ef017
commit
060b2db0df
|
@ -148,6 +148,7 @@ pub enum Config {
|
|||
Idefics,
|
||||
Idefics2(Idefics2),
|
||||
Ssm,
|
||||
Mamba,
|
||||
GptBigcode,
|
||||
Santacoder,
|
||||
Bloom,
|
||||
|
|
|
@ -210,7 +210,7 @@ class ModelType(enum.Enum):
|
|||
"url": "https://huggingface.co/databricks/dbrx-instruct",
|
||||
}
|
||||
MAMBA = {
|
||||
"type": "ssm",
|
||||
"type": ["ssm", "mamba"],
|
||||
"name": "Mamba",
|
||||
"url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj",
|
||||
}
|
||||
|
@ -526,7 +526,7 @@ def get_model(
|
|||
dtype=dtype,
|
||||
trust_remote_code=trust_remote_code,
|
||||
)
|
||||
elif model_type == MAMBA:
|
||||
elif model_type in MAMBA:
|
||||
return Mamba(
|
||||
model_id,
|
||||
revision,
|
||||
|
|
Loading…
Reference in New Issue