diff --git a/router/src/config.rs b/router/src/config.rs index 7737165e..c7cbd7d0 100644 --- a/router/src/config.rs +++ b/router/src/config.rs @@ -148,6 +148,7 @@ pub enum Config { Idefics, Idefics2(Idefics2), Ssm, + Mamba, GptBigcode, Santacoder, Bloom, diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 3dc24159..8046b457 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -210,7 +210,7 @@ class ModelType(enum.Enum): "url": "https://huggingface.co/databricks/dbrx-instruct", } MAMBA = { - "type": "ssm", + "type": ["ssm", "mamba"], "name": "Mamba", "url": "https://huggingface.co/state-spaces/mamba-2.8b-slimpj", } @@ -526,7 +526,7 @@ def get_model( dtype=dtype, trust_remote_code=trust_remote_code, ) - elif model_type == MAMBA: + elif model_type in MAMBA: return Mamba( model_id, revision,