diff --git a/Dockerfile b/Dockerfile index b6c5b2ed..6818005f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -154,6 +154,12 @@ COPY server/Makefile-vllm Makefile # Build specific version of vllm RUN make build-vllm-cuda +# Build mamba kernels +FROM kernel-builder as mamba-builder +WORKDIR /usr/src +COPY server/Makefile-selective-scan Makefile +RUN make build-all + # Build megablocks FROM kernel-builder as megablocks-builder @@ -205,6 +211,10 @@ COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-31 # Copy builds artifacts from vllm builder COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages +# Copy build artifacts from mamba builder +COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages +COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages + # Install flash-attention dependencies RUN pip install einops --no-cache-dir diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json new file mode 100644 index 00000000..4435f215 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba.json @@ -0,0 +1,73 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.3552246, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38378906, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.140625, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5551758, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59033203, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.70654297, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0410156, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0026435852, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2841797, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json new file mode 100644 index 00000000..052c1c69 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_all_params.json @@ -0,0 +1,99 @@ +{ + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 2502, + "logprob": null, + "text": " red" + }, + { + "id": 13, + "logprob": -2.5234375, + "text": "," + }, + { + "id": 8862, + "logprob": -3.4433594, + "text": " yellow" + }, + { + "id": 13, + "logprob": -0.43017578, + "text": "," + }, + { + "id": 209, + "logprob": -8.21875, + "text": " " + } + ], + "seed": 0, + "tokens": [ + { + "id": 187, + "logprob": 0.0, + "special": false, + "text": "\n" + }, + { + "id": 395, + "logprob": -0.46411133, + "special": false, + "text": "and" + }, + { + "id": 13735, + "logprob": -2.1132812, + "special": false, + "text": " orange" + }, + { + "id": 313, + "logprob": -1.2128906, + "special": false, + "text": " (" + 
}, + { + "id": 249, + "logprob": -2.3671875, + "special": false, + "text": "in" + }, + { + "id": 253, + "logprob": 0.0, + "special": false, + "text": " the" + }, + { + "id": 1340, + "logprob": -1.640625, + "special": false, + "text": " order" + }, + { + "id": 597, + "logprob": -0.5488281, + "special": false, + "text": " they" + }, + { + "id": 3176, + "logprob": -0.48608398, + "special": false, + "text": " appear" + }, + { + "id": 275, + "logprob": 0.0, + "special": false, + "text": " in" + } + ], + "top_tokens": null + }, + "generated_text": "blue, red, yellow, \nand orange (in the order they appear in" +} diff --git a/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json new file mode 100644 index 00000000..014210b2 --- /dev/null +++ b/integration-tests/models/__snapshots__/test_mamba/test_mamba_load.json @@ -0,0 +1,398 @@ +[ + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.8125, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.828125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -3.0, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1484375, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.3552246, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38378906, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1279297, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.5595703, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.60253906, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7050781, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0488281, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3808594, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0026416779, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" + } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + }, + { + "details": { + "best_of_sequences": null, + "finish_reason": "length", + "generated_tokens": 10, + "prefill": [ + { + "id": 1276, + "logprob": null, + "text": "What" + }, + { + "id": 310, + "logprob": -0.78027344, + "text": " is" + }, + { + "id": 18147, + "logprob": -12.8203125, + "text": " Deep" + }, + { + "id": 20727, + "logprob": -2.9902344, + "text": " Learning" + }, + { + "id": 32, + "logprob": -1.1523438, + "text": "?" 
+ } + ], + "seed": null, + "tokens": [ + { + "id": 187, + "logprob": -0.35351562, + "special": false, + "text": "\n" + }, + { + "id": 187, + "logprob": -0.38256836, + "special": false, + "text": "\n" + }, + { + "id": 30763, + "logprob": -1.1269531, + "special": false, + "text": "Deep" + }, + { + "id": 4715, + "logprob": -0.54541016, + "special": false, + "text": " learning" + }, + { + "id": 310, + "logprob": -0.59765625, + "special": false, + "text": " is" + }, + { + "id": 247, + "logprob": -0.7001953, + "special": false, + "text": " a" + }, + { + "id": 747, + "logprob": -2.0585938, + "special": false, + "text": " new" + }, + { + "id": 1511, + "logprob": -2.3789062, + "special": false, + "text": " type" + }, + { + "id": 273, + "logprob": -0.0027446747, + "special": false, + "text": " of" + }, + { + "id": 5145, + "logprob": -1.2851562, + "special": false, + "text": " machine" + } + ], + "top_tokens": null + }, + "generated_text": "\n\nDeep learning is a new type of machine" + } +] diff --git a/integration-tests/models/test_mamba.py b/integration-tests/models/test_mamba.py new file mode 100644 index 00000000..d86faeff --- /dev/null +++ b/integration-tests/models/test_mamba.py @@ -0,0 +1,59 @@ +import pytest + + +@pytest.fixture(scope="module") +def fused_kernel_mamba_handle(launcher): + with launcher("state-spaces/mamba-130m", num_shard=1) as handle: + yield handle + + +@pytest.fixture(scope="module") +async def fused_kernel_mamba(fused_kernel_mamba_handle): + await fused_kernel_mamba_handle.health(300) + return fused_kernel_mamba_handle.client + + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "What is Deep Learning?", max_new_tokens=10 + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "\n\nDeep learning is a new type of machine" + assert response == response_snapshot + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba_all_params(fused_kernel_mamba, response_snapshot): + response = await fused_kernel_mamba.generate( + "blue, red, yellow, ", + max_new_tokens=10, + repetition_penalty=1.2, + return_full_text=True, + stop_sequences=["test"], + temperature=0.5, + top_p=0.9, + top_k=10, + truncate=5, + typical_p=0.9, + watermark=True, + decoder_input_details=True, + seed=0, + ) + + assert response.details.generated_tokens == 10 + assert response.generated_text == "blue, red, yellow, \nand orange (in the order they appear in" + assert response == response_snapshot + +@pytest.mark.asyncio +@pytest.mark.private +async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): + responses = await generate_load(fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4) + + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" + + assert responses == response_snapshot diff --git a/server/.gitignore b/server/.gitignore index dcb8fe67..576746ee 100644 --- a/server/.gitignore +++ b/server/.gitignore @@ -161,3 +161,4 @@ flash-attention-v2/ vllm/ llm-awq/ eetq/ +mamba/ diff --git a/server/Makefile b/server/Makefile index b1926828..31d55c41 100644 --- a/server/Makefile +++ b/server/Makefile @@ -3,6 +3,7 @@ include Makefile-flash-att-v2 include Makefile-vllm include Makefile-awq include Makefile-eetq +include Makefile-selective-scan unit-tests: pytest -s -vv -m "not private" tests 
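The integration tests above (integration-tests/models/test_mamba.py) exercise the new code path end to end by launching state-spaces/mamba-130m and generating against it. As a quick manual check outside pytest, something like the sketch below could reproduce test_mamba against an already running server; the base URL and the synchronous text_generation.Client usage are assumptions on my part, not part of this diff.

# Minimal sketch (not part of this diff): query a text-generation-inference
# server launched with `--model-id state-spaces/mamba-130m`.
# The base URL below is an assumption.
from text_generation import Client

client = Client("http://127.0.0.1:8080")

# Mirrors integration-tests/models/test_mamba.py::test_mamba
response = client.generate("What is Deep Learning?", max_new_tokens=10)

assert response.details.generated_tokens == 10
print(response.generated_text)  # expected: "\n\nDeep learning is a new type of machine"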
diff --git a/server/Makefile-selective-scan b/server/Makefile-selective-scan new file mode 100644 index 00000000..f4dec868 --- /dev/null +++ b/server/Makefile-selective-scan @@ -0,0 +1,28 @@ +selective_scan_commit := 2a3704fd47ba817b415627b06fd796b971fdc137 + +causal-conv1d: + rm -rf causal-conv1d + git clone https://github.com/Dao-AILab/causal-conv1d.git + +build-causal-conv1d: causal-conv1d + cd causal-conv1d/ && git checkout v1.1.1 # known latest working version tag + cd causal-conv1d/ && CAUSAL_CONV1D_FORCE_BUILD=TRUE python setup.py build + +install-causal-conv1d: build-causal-conv1d + pip uninstall causal-conv1d -y || true + cd causal-conv1d/ && pip install . + +# selective-scan dependends on causal-conv1d +selective-scan: + rm -rf mamba + git clone https://github.com/state-spaces/mamba.git mamba + +build-selective-scan: selective-scan + cd mamba/ && git fetch && git checkout $(selective_scan_commit) + cd mamba && python setup.py build + +install-selective-scan: install-causal-conv1d build-selective-scan + pip uninstall selective-scan-cuda -y || true + cd mamba && pip install . + +build-all: build-causal-conv1d build-selective-scan \ No newline at end of file diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 68096709..a952f060 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -76,6 +76,15 @@ if FLASH_ATTENTION: __all__.append(FlashMixtral) __all__.append(FlashPhi) +MAMBA_AVAILABLE = True +try: + from text_generation_server.models.mamba import Mamba +except ImportError as e: + logger.warning(f"Could not import Mamba: {e}") + MAMBA_AVAILABLE = False + +if MAMBA_AVAILABLE: + __all__.append(Mamba) def get_model( model_id: str, @@ -164,7 +173,25 @@ def get_model( if speculate > 0: logger.info(f"Using speculation {method} with {speculate} input ids.") - model_type = config_dict["model_type"] + model_type = config_dict.get("model_type", None) + if model_type is None: + # TODO: fix how we determine model type for Mamba + if "ssm_cfg" in config_dict: + # *only happens in Mamba case + model_type = "ssm" + else: + raise RuntimeError( + f"Could not determine model type for {model_id} revision {revision}" + ) + + if model_type == "ssm": + return Mamba( + model_id, + revision, + quantize=quantize, + dtype=dtype, + trust_remote_code=trust_remote_code, + ) if model_type == "gpt_bigcode": if FLASH_ATTENTION: diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py new file mode 100644 index 00000000..1773f04d --- /dev/null +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -0,0 +1,194 @@ +import torch +import torch.distributed + +from mamba_ssm.ops.triton.selective_state_update import selective_state_update +from mamba_ssm.ops.selective_scan_interface import selective_scan_fn +from mamba_ssm.utils.generation import InferenceParams +from torch import nn +from typing import Optional, Tuple, Any +from transformers.configuration_utils import PretrainedConfig +import torch.nn.functional as F + +from text_generation_server.utils.layers import ( + TensorParallelEmbedding, + FastRMSNorm, + FastLinear, +) + +from einops import rearrange +from causal_conv1d import causal_conv1d_fn, causal_conv1d_update +import math + +class MambaConfig(PretrainedConfig): + def __init__( + self, + vocab_size=50280, + d_model=768, + d_state=16, + n_layer=32, + 
layer_norm_epsilon=1e-5, + tie_word_embeddings=False, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + expand=2, + dt_rank="auto", + **kwargs, + ): + self.vocab_size = vocab_size + self.n_layer = n_layer + self.layer_norm_epsilon = layer_norm_epsilon + self.d_model = d_model + self.d_inner = d_model * 2 + self.d_conv = 4 + self.d_state = d_state + self.expand = expand + self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + +class MambaBlock(nn.Module): + def __init__(self, prefix, config, weights): + super().__init__() + self.layer_idx = int(prefix.split(".")[2]) + self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) + self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) + self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) + self.dt_proj_no_bias = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=False) + self.out_proj = FastLinear.load(config, f"{prefix}.out_proj", weights, bias=False) + self.conv1d = FastLinear.load(config, f"{prefix}.conv1d", weights, bias=True) + self.negA = -torch.exp(weights.get_tensor(f"{prefix}.A_log").float()) + self.D = weights.get_tensor(f"{prefix}.D") + self.activation = "silu" + self.dt_rank = config.dt_rank + self.d_state = config.d_state + self.d_conv = config.d_conv + self.act = nn.SiLU() + + # inference_params + def forward(self, hidden_states: torch.Tensor, inference_params=None): + _, seqlen, _ = hidden_states.shape + conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx] + + if inference_params.seqlen_offset > 0: + out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) + return out, conv_state, ssm_state + + projected_states = self.in_proj(hidden_states).transpose(1,2) + x, z = projected_states.chunk(2, dim=1) + conv_state = F.pad(x, (self.d_conv - seqlen, 0)) + x = causal_conv1d_fn( + x=x, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation, + ) + + # We're careful here about the layout, to avoid extra transposes. + # We want dt to have d as the slowest moving dimension + # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. 
+ x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d")) # (bl d) + dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = self.dt_proj.weight @ dt.t() + dt = rearrange(dt, "d (b l) -> b d l", l=seqlen) + B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous() + y, last_state = selective_scan_fn( + x, + dt, + self.negA, + B, + C, + self.D.float(), + z=z, + delta_bias=self.dt_proj.bias.float(), + delta_softplus=True, + return_last_state=True, + ) + y = rearrange(y, "b d l -> b l d") + attn_outputs = self.out_proj(y) + return attn_outputs, conv_state, last_state + + def step(self, hidden_states, conv_state, ssm_state): + _xz = self.in_proj(hidden_states) + _x, _z = _xz.chunk(2, dim=-1) # (B D) + conv_state_new = torch.cat([conv_state, _x.transpose(1,2)], dim=-1) + conv_out = causal_conv1d_fn( + x=conv_state_new, + weight=self.conv1d.weight.view(self.conv1d.weight.size(0), self.conv1d.weight.size(2)), + bias=self.conv1d.bias, + activation=self.activation + ) + conv_state = conv_state_new[:, :, 1:] + bsz, seqlen, dim = hidden_states.shape + output_tensor = torch.zeros( + (bsz, seqlen, dim), + device=hidden_states.device, + dtype=hidden_states.dtype + ) + for i in range(0, bsz): + x = conv_out[i:i+1,:,-1] + z = _z[i:i+1, -1, :] + x_db = self.x_proj(x) + dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1) + dt = F.linear(dt, self.dt_proj.weight) + y = selective_state_update( + ssm_state[i:i+1,:,:], x, dt, self.negA, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True + ) + out = self.out_proj(y) + output_tensor[i] = out + + return output_tensor, conv_state, ssm_state + + + +class ResidualBlock(nn.Module): + def __init__(self, layer_id, config, weights): + super().__init__() + self.mamba_block = MambaBlock(prefix=f"{layer_id}.mixer", config=config, weights=weights) + self.layer_norm = FastRMSNorm.load(prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = None, + inference_params: Optional[Any] = None, + ): + residual = (hidden_states + residual) if residual is not None else hidden_states + shape = residual.shape + hidden_states, _ = self.layer_norm(residual.view(-1, shape[-1])) + hidden_states, conv_state, last_ssm_state = self.mamba_block(hidden_states.view(*shape), inference_params) + return hidden_states, residual, conv_state, last_ssm_state + +class MambaModel(nn.Module): + def __init__(self, config, weights): + super().__init__() + prefix = "backbone" + self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) + self.blocks = nn.ModuleList( + [ResidualBlock(f"{prefix}.layers.{i}", config, weights) for i in range(config.n_layer)] + ) + self.norm_f = FastRMSNorm.load(f"{prefix}.norm_f", weights, eps=config.layer_norm_epsilon) + self.lm_head = FastLinear.load(config, f"{prefix}.embedding", weights, bias=False) + self.config = config + + def forward(self, input_ids: torch.Tensor, inference_params=None, residual=None) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: + hidden_states = self.embed_tokens(input_ids) + for block in self.blocks: + hidden_states, residual, conv_state, ssm_state = block(hidden_states, residual, inference_params) + inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = (conv_state, ssm_state) + + hidden_states = hidden_states + residual if residual is not None else 
hidden_states + hidden_states, _ = self.norm_f(hidden_states.view(-1, hidden_states.size(-1))) + hidden_states = hidden_states.view(residual.shape) + logits = self.lm_head(hidden_states) + + # update the offset for the next inference using these params + inference_params.seqlen_offset += input_ids.size(1) + return logits, input_ids, inference_params \ No newline at end of file diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py new file mode 100644 index 00000000..c10910aa --- /dev/null +++ b/server/text_generation_server/models/mamba.py @@ -0,0 +1,656 @@ +import torch +import torch.distributed +from transformers import AutoTokenizer, PreTrainedTokenizerBase +from typing import Optional +from text_generation_server.models.custom_modeling.mamba_modeling import ( + MambaConfig, +) +from text_generation_server.pb import generate_pb2 +from text_generation_server.utils import ( + initialize_torch_distributed, + weight_files, + Weights, +) +import time +from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel +from text_generation_server.models import Model +from typing import Any, List, Optional, Tuple, Type, Dict +from text_generation_server.models.types import ( + Batch, + Tokens, + Generation, + GeneratedText, +) +from text_generation_server.utils.tokens import batch_top_tokens, Sampling +from dataclasses import dataclass +from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling +from mamba_ssm.utils.generation import InferenceParams + +@dataclass +class MambaBatch(Batch): + batch_id: int + requests: List[generate_pb2.Request] + requests_idx_mapping: Dict[int, int] + + # Decoder values + input_ids: torch.Tensor + + # All tokens + all_input_ids: List[torch.Tensor] + + # Lengths of all generations present in the batch + input_lengths: List[int] + prefix_offsets: List[int] + read_offsets: List[int] + + # Generation helpers + next_token_choosers: List[NextTokenChooser] + stopping_criterias: List[StoppingCriteria] + top_n_tokens: List[int] + top_n_tokens_tensor: torch.Tensor + + # Metadata used for padding + max_input_length: int + padding_right_offset: int + + # Maximum number of tokens this batch will grow to + max_tokens: int + + # Past metadata + keys_head_dim_last: bool = True + + # Inference params + inference_params: Optional[Dict[str, Any]] = None + + def to_pb(self) -> generate_pb2.CachedBatch: + return generate_pb2.CachedBatch( + id=self.batch_id, + request_ids=[r.id for r in self.requests], + size=len(self), + max_tokens=self.max_tokens, + ) + + @classmethod + def from_pb( + cls, + pb: generate_pb2.Batch, + tokenizer: PreTrainedTokenizerBase, + dtype: torch.dtype, + device: torch.device, + ) -> "MambaBatch": + inputs = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + prefix_offsets = [] + read_offsets = [] + requests_idx_mapping = {} + + # Parse batch + max_truncation = 0 + padding_right_offset = 0 + max_decode_tokens = 0 + for i, r in enumerate(pb.requests): + requests_idx_mapping[r.id] = i + inputs.append(r.inputs) + next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device)) + stopping_criteria = StoppingCriteria.from_pb( + r.stopping_parameters, tokenizer + ) + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(r.top_n_tokens) + max_truncation = max(max_truncation, r.truncate) + max_decode_tokens += stopping_criteria.max_new_tokens + padding_right_offset = max( + padding_right_offset, 
stopping_criteria.max_new_tokens + ) + + tokenized_inputs = tokenizer( + inputs, + return_tensors="pt", + padding=True, + return_token_type_ids=False, + truncation=True, + max_length=max_truncation, + ).to(device) + for _ in pb.requests: + input_len = tokenized_inputs["input_ids"].shape[1] + prefix_offsets.append(input_len - 5) + read_offsets.append(input_len) + + input_lengths = tokenized_inputs["attention_mask"].sum(1) + max_input_length = input_lengths.max() + input_ids = tokenized_inputs["input_ids"] + all_input_ids = tokenized_inputs["input_ids"].T.split(1, dim=1) + top_n_tokens_tensor = torch.tensor( + top_n_tokens, device=device, dtype=torch.int64 + ) + max_tokens = len(inputs) * (max_input_length + max_decode_tokens) + return cls( + batch_id=pb.id, + requests=pb.requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + # past_input_ids=None, + all_input_ids=list(all_input_ids), + input_lengths=input_lengths.tolist(), + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length.item(), + padding_right_offset=padding_right_offset, + max_tokens=max_tokens, + ) + + def filter(self, request_ids: List[int]) -> Optional["MambaBatch"]: + if len(request_ids) == 0: + raise ValueError("Batch must have at least one request") + if len(request_ids) == len(self): + return self + + keep_indices = [] + + # New values after filtering + requests_idx_mapping = {} + requests = [] + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + max_input_length = 0 + + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + + total_remaining_decode_tokens = 0 + new_padding_right_offset = 0 + + indices = [] + for i, request_id in enumerate(request_ids): + idx = self.requests_idx_mapping[request_id] + requests_idx_mapping[request_id] = i + keep_indices.append(idx) + + requests.append(self.requests[idx]) + prefix_offsets.append(self.prefix_offsets[idx]) + read_offsets.append(self.read_offsets[idx]) + all_input_ids.append(self.all_input_ids[idx]) + + request_input_length = self.input_lengths[idx] + input_lengths.append(request_input_length) + max_input_length = max(max_input_length, request_input_length) + indices.append(idx) + + next_token_choosers.append(self.next_token_choosers[idx]) + stopping_criteria = self.stopping_criterias[idx] + stopping_criterias.append(stopping_criteria) + top_n_tokens.append(self.top_n_tokens[idx]) + remaining_decode_tokens = ( + stopping_criteria.max_new_tokens - stopping_criteria.current_tokens + ) + total_remaining_decode_tokens += remaining_decode_tokens + new_padding_right_offset = max( + new_padding_right_offset, remaining_decode_tokens + ) + + # Apply indices to input_ids, attention mask, past key values and other items that need to be cached + input_ids = self.input_ids[keep_indices] + + top_n_tokens_tensor = self.top_n_tokens_tensor[keep_indices] + max_tokens = len(request_ids) * max_input_length + total_remaining_decode_tokens + + self.requests = requests + self.requests_idx_mapping = requests_idx_mapping + self.input_ids = input_ids + self.all_input_ids = all_input_ids + self.input_lengths = input_lengths + self.prefix_offsets = prefix_offsets + self.read_offsets = read_offsets + self.next_token_choosers = next_token_choosers + self.stopping_criterias = stopping_criterias + self.top_n_tokens = top_n_tokens + 
self.top_n_tokens_tensor = top_n_tokens_tensor + self.max_input_length = max_input_length + self.padding_right_offset = new_padding_right_offset + self.max_tokens = max_tokens + + # TODO + # Kept it simple by just updating the state, maybe updating the other CPU values is necessary. + key_value_memory_dict = {} + for i, (conv_state, ssm_state) in self.inference_params.key_value_memory_dict.items(): + key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices]) + self.inference_params.key_value_memory_dict = key_value_memory_dict + + return self + + @classmethod + def concatenate(cls, batches: List["MambaBatch"]) -> "MambaBatch": + # Used for padding + total_batch_size = 0 + max_input_length = 0 + padding_right_offset = 0 + for batch in batches: + total_batch_size += len(batch) + max_input_length = max(max_input_length, batch.max_input_length) + padding_right_offset = max(padding_right_offset, batch.padding_right_offset) + + # Batch attributes + requests = [] + requests_idx_mapping = {} + input_lengths = [] + prefix_offsets = [] + read_offsets = [] + all_input_ids = [] + next_token_choosers = [] + stopping_criterias = [] + top_n_tokens = [] + max_tokens = 0 + max_seqlen = 0 + batch_size = 0 + seqlen_offset = 0 + + # Batch tensors + input_ids = None + top_n_tokens_tensor = None + + # Used for slicing correctly inside the tensors + # Equivalent to a cumsum on batch sizes + start_index = 0 + for i, batch in enumerate(batches): + requests.extend(batch.requests) + input_lengths.extend(batch.input_lengths) + prefix_offsets.extend(batch.prefix_offsets) + read_offsets.extend(batch.read_offsets) + all_input_ids.extend(batch.all_input_ids) + next_token_choosers.extend(batch.next_token_choosers) + stopping_criterias.extend(batch.stopping_criterias) + top_n_tokens.extend(batch.top_n_tokens) + + if i == 0: + requests_idx_mapping = batch.requests_idx_mapping + else: + # We need to offset the mapping for each batch by the cumulative batch size + for k, v in batch.requests_idx_mapping.items(): + requests_idx_mapping[k] = v + start_index + + # Slicing end index for this batch + end_index = start_index + len(batch) + + # Create empty tensor + # input_ids is always of shape [batch_size, 1] + # We do not need to pad it + if input_ids is None: + input_ids = batch.input_ids.new_empty((total_batch_size, 1)) + # Copy to correct indices + input_ids[start_index:end_index] = batch.input_ids + + if top_n_tokens_tensor is None: + top_n_tokens_tensor = batches[0].top_n_tokens_tensor.new_zeros( + total_batch_size, + ) + top_n_tokens_tensor[start_index:end_index] = batch.top_n_tokens_tensor + + # Add eventual padding tokens that were added while concatenating + max_tokens += batch.max_tokens + ( + max_input_length - batch.max_input_length + ) * len(batch) + + max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen) + seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset) + batch_size += batch.inference_params.max_batch_size + + start_index = end_index + + + (_, d_model, d_conv) = batches[0].inference_params.key_value_memory_dict[0][0].shape + (_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape + n_blocks = len(batches[0].inference_params.key_value_memory_dict) + dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype + device = batches[0].inference_params.key_value_memory_dict[0][0].device + + key_value_memory_dict = {} + for i in range(n_blocks): + conv_state = torch.zeros( + batch_size, + d_model, + d_conv, + device=device, + dtype=dtype, + ) + 
ssm_state = torch.zeros( + batch_size, + d_model, + d_state, + device=device, + dtype=dtype, + ) + key_value_memory_dict[i] = (conv_state, ssm_state) + lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device) + + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_offset, + key_value_memory_dict=key_value_memory_dict, + lengths_per_sample=lengths_per_sample, + ) + + current_batch = 0 + for batch in batches: + for i in range(n_blocks): + conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i] + batch_size = batch.inference_params.max_batch_size + inference_params.key_value_memory_dict[i][0][current_batch:current_batch + batch_size] = conv_state + inference_params.key_value_memory_dict[i][1][current_batch:current_batch + batch_size] = ssm_state + inference_params.lengths_per_sample[current_batch: current_batch + batch_size] = batch.inference_params.lengths_per_sample + current_batch += batch_size + + return cls( + batch_id=batches[0].batch_id, + requests=requests, + requests_idx_mapping=requests_idx_mapping, + input_ids=input_ids, + all_input_ids=all_input_ids, + input_lengths=input_lengths, + prefix_offsets=prefix_offsets, + read_offsets=read_offsets, + next_token_choosers=next_token_choosers, + stopping_criterias=stopping_criterias, + top_n_tokens=top_n_tokens, + top_n_tokens_tensor=top_n_tokens_tensor, + max_input_length=max_input_length, + padding_right_offset=padding_right_offset, + keys_head_dim_last=batches[0].keys_head_dim_last, + max_tokens=max_tokens, + inference_params=inference_params + ) + + def __len__(self): + return len(self.requests) + +class Mamba(Model): + def __init__( + self, + model_id: str, + revision: Optional[str] = None, + quantize: Optional[str] = None, + dtype: Optional[torch.dtype] = None, + trust_remote_code: bool = False, + ): + self.process_group, _rank, _world_size = initialize_torch_distributed() + if torch.cuda.is_available(): + device = torch.device("cuda") + dtype = torch.float16 if dtype is None else dtype + else: + if quantize: + raise ValueError("quantization is not available on CPU") + + device = torch.device("cpu") + dtype = torch.float32 if dtype is None else dtype + + tokenizer = AutoTokenizer.from_pretrained( + "EleutherAI/gpt-neox-20b", + revision=revision, + padding_side="left", + truncation_side="left", + trust_remote_code=trust_remote_code, + ) + config = MambaConfig.from_pretrained( + model_id, revision=revision, trust_remote_code=trust_remote_code + ) + + tokenizer.bos_token_id = config.bos_token_id + tokenizer.eos_token_id = config.eos_token_id + tokenizer.pad_token = tokenizer.eos_token + + config.quantize = quantize + torch.distributed.barrier(group=self.process_group) + filenames = weight_files(model_id, revision=revision, extension=".safetensors") + weights = Weights(filenames, device, dtype, process_group=self.process_group) + model = MambaModel(config, weights) + torch.distributed.barrier(group=self.process_group) + super(Mamba, self).__init__( + model=model, + tokenizer=tokenizer, + requires_padding=True, + dtype=dtype, + device=device, + ) + + @property + def batch_type(self) -> Type[MambaBatch]: + return MambaBatch + + def warmup(self, batch) -> Optional[int]: + # TODO: implement warmup for Mamba if needed + return None + + def forward( + self, + input_ids: torch.Tensor, + past: Optional[List[torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + return self.model( + input_ids, + past=past, + ) + + def generate_token(self, 
batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: + start = time.time_ns() + input_ids = batch.input_ids # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids + + batch_size = input_ids.shape[0] + max_seqlen = input_ids.shape[1] + dtype = input_ids.dtype + + # Inference params + seqlen_og = 0 + inf_cache = {} + lengths_per_sample = torch.ones(batch_size, dtype=torch.int32, device=input_ids.device) * max_seqlen + + if batch.inference_params is None: + inference_params = InferenceParams( + max_seqlen=max_seqlen, + max_batch_size=batch_size, + seqlen_offset=seqlen_og, + key_value_memory_dict=inf_cache, + lengths_per_sample=lengths_per_sample, + ) + + # Allocate inference cache + for res_block in self.model.blocks: + block = res_block.mamba_block + conv_state = torch.zeros( + batch_size, + self.model.config.d_model * self.model.config.expand, + self.model.config.d_conv, + device=block.conv1d.weight.device, + dtype=block.conv1d.weight.dtype, + ) + ssm_state = torch.zeros( + batch_size, + self.model.config.d_model * self.model.config.expand, + self.model.config.d_state, + device=block.dt_proj.weight.device, + dtype=block.dt_proj.weight.dtype, + ) + inference_params.key_value_memory_dict[block.layer_idx] = (conv_state, ssm_state) + batch.inference_params = inference_params + + # Forward pass + logits, past_input_ids, new_inference_params = self.model(input_ids, batch.inference_params) + + batch.inference_params = new_inference_params + # Results + generations: List[Generation] = [] + stopped = True + + # Speculation is not active for causal + accepted_ids = torch.ones_like(batch.input_ids)[:, 0] + batch_top_token_ids, batch_top_token_logprobs = batch_top_tokens( + batch.top_n_tokens, + batch.top_n_tokens_tensor, + torch.log_softmax(logits[:, -1], -1), + accepted_ids, + ) + + start_decode = time.time_ns() + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + batch.top_n_tokens, + batch_top_token_ids, + batch_top_token_logprobs, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + top_n_tokens, + top_token_ids, + top_token_logprobs, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser( + all_input_ids.view(1, -1), logits[-1:, :] + ) + + # Append next token to all tokens + all_input_ids = torch.cat([all_input_ids, next_token_id]) + new_input_length = input_length + 1 + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + next_token_id_squeezed = next_token_id.squeeze() + next_token_text, prefix_offset, read_offset = self.decode_token( + all_input_ids[:, 0], prefix_offset, read_offset + ) + + # Evaluate stopping criteria + stop, reason = stopping_criteria( + next_token_id_squeezed, + next_token_text, + ) + + if not stop: + stopped = False + + # Shard generations + # All generations will be appended in the rust sharded client + if i % self.world_size == self.rank: + if stop: + # Decode generated tokens + output_text, _, _ = self.decode_token( + all_input_ids[:, 0], + prefix_offset=len(all_input_ids) + - stopping_criteria.current_tokens + - 1, + read_offset=len(all_input_ids) - stopping_criteria.current_tokens, + skip_special_tokens=True, + ) + # Get seed + if isinstance(next_token_chooser.choice, Sampling): + seed = 
next_token_chooser.choice.seed + else: + seed = None + + generated_text = GeneratedText( + output_text, stopping_criteria.current_tokens, reason, seed + ) + else: + generated_text = None + + if stopping_criteria.current_tokens == 1 and request.prefill_logprobs: + # Remove generated token to only have prefill and add nan for first prompt token + prefill_logprobs = [float("nan")] + torch.log_softmax( + logits, -1 + ).gather(1, all_input_ids[1:]).squeeze(1)[ + -new_input_length:-1 + ].tolist() + prefill_token_ids = all_input_ids[-new_input_length:-1] + prefill_texts = self.tokenizer.batch_decode( + prefill_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + prefill_tokens = Tokens( + prefill_token_ids, + prefill_logprobs, + prefill_texts, + is_special=[], + ) + else: + prefill_tokens = None + + if top_n_tokens > 0: + toptoken_texts = self.tokenizer.batch_decode( + top_token_ids, + clean_up_tokenization_spaces=False, + skip_special_tokens=False, + ) + special_toptokens = [ + token_id in self.all_special_ids for token_id in top_token_ids + ] + top_tokens = Tokens( + top_token_ids, + top_token_logprobs, + toptoken_texts, + special_toptokens, + ) + else: + top_tokens = None + + generation = Generation( + request.id, + prefill_tokens, + Tokens( + [next_token_id_squeezed], + [next_token_logprob], + [next_token_text], + [next_token_id_squeezed.item() in self.all_special_ids], + ), + generated_text, + top_tokens, + ) + + generations.append(generation) + + # Update values + batch.input_ids[i, 0] = next_token_id + batch.all_input_ids[i] = all_input_ids + batch.input_lengths[i] = new_input_length + batch.prefix_offsets[i] = prefix_offset + batch.read_offsets[i] = read_offset + batch.max_input_length = max(batch.max_input_length, new_input_length) + + # We finished all generations in the batch; there is no next batch + if stopped: + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, None, (forward_ns, decode_ns) + + # Slice unused values from prefill + batch.input_ids = batch.input_ids[:, :1] + + forward_ns = start_decode - start + decode_ns = time.time_ns() - start_decode + return generations, batch, (forward_ns, decode_ns)
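For reference, the per-layer cache that Mamba.generate_token allocates and that MambaBatch.filter/concatenate slice and merge holds one convolution state and one SSM state per MambaBlock. Their middle dimension is the expanded inner size d_inner = expand * d_model (concatenate unpacks it under the name d_model, but it is the expanded dimension). Below is a standalone sketch of that layout using the MambaConfig defaults from mamba_modeling.py; the allocate_cache helper is illustrative only and not part of the server code.

import torch

# Illustrative sketch (not part of this diff) of the per-layer cache kept in
# InferenceParams.key_value_memory_dict:
#   conv_state: (batch_size, d_inner, d_conv)  - rolling window for the causal conv1d step
#   ssm_state:  (batch_size, d_inner, d_state) - recurrent state for the selective scan
# with d_inner = expand * d_model, matching the allocations in Mamba.generate_token.
def allocate_cache(
    batch_size: int,
    n_layer: int = 32,   # MambaConfig defaults from mamba_modeling.py
    d_model: int = 768,
    expand: int = 2,
    d_conv: int = 4,
    d_state: int = 16,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.float32,
):
    d_inner = expand * d_model
    cache = {}
    for layer_idx in range(n_layer):
        conv_state = torch.zeros(batch_size, d_inner, d_conv, device=device, dtype=dtype)
        ssm_state = torch.zeros(batch_size, d_inner, d_state, device=device, dtype=dtype)
        cache[layer_idx] = (conv_state, ssm_state)
    return cache

cache = allocate_cache(batch_size=4)
print(cache[0][0].shape, cache[0][1].shape)
# torch.Size([4, 1536, 4]) torch.Size([4, 1536, 16])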