Improving mamba runtime by using updates (#1552)

- Move float16 to bfloat16, which has less imprecisions (load test are
  failing with the update kernels + f16, all working under bf16).

  Another note, is that we are not respecting the layer norm in f32
  defined in the configuration (this is OK in my book, but that could
  impact the f16 precision)

- Moved to update kernels. Triton overhead is super high, removed by
  switching to cuda graphs works great (update cuda graph is available
  in TRT-LLM if needed, seems *exactly* like the regular ssm kernel.

- Moved inference_params struct in order to make only 2 tensors, to
  reduce the overhead of copying back and forth to the cuda graphs.

- Left over overhead seems entirely in the tokenization bit. (Still 4
  copies are paid before launching the graph)


# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->
This commit is contained in:
Nicolas Patry 2024-02-14 09:54:10 +01:00 committed by GitHub
parent 7671a419a0
commit d6b0fb9e25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 300 additions and 284 deletions

View File

@ -8,61 +8,61 @@
"tokens": [ "tokens": [
{ {
"id": 187, "id": 187,
"logprob": -0.3552246, "logprob": -0.37890625,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 187, "id": 187,
"logprob": -0.38378906, "logprob": -0.26953125,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 30763, "id": 30763,
"logprob": -1.140625, "logprob": -1.1953125,
"special": false, "special": false,
"text": "Deep" "text": "Deep"
}, },
{ {
"id": 4715, "id": 4715,
"logprob": -0.5551758, "logprob": -0.53515625,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.59033203, "logprob": -0.625,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 247, "id": 247,
"logprob": -0.70654297, "logprob": -0.6796875,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 747, "id": 747,
"logprob": -2.0410156, "logprob": -2.0,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 1511, "id": 1511,
"logprob": -2.3789062, "logprob": -2.3125,
"special": false, "special": false,
"text": " type" "text": " type"
}, },
{ {
"id": 273, "id": 273,
"logprob": -0.0026435852, "logprob": -0.0028533936,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5145, "id": 5145,
"logprob": -1.2841797, "logprob": -1.265625,
"special": false, "special": false,
"text": " machine" "text": " machine"
} }

View File

@ -11,22 +11,22 @@
}, },
{ {
"id": 13, "id": 13,
"logprob": -2.5234375, "logprob": -2.734375,
"text": "," "text": ","
}, },
{ {
"id": 8862, "id": 8862,
"logprob": -3.4433594, "logprob": -3.6875,
"text": " yellow" "text": " yellow"
}, },
{ {
"id": 13, "id": 13,
"logprob": -0.43017578, "logprob": -0.40234375,
"text": "," "text": ","
}, },
{ {
"id": 209, "id": 209,
"logprob": -8.21875, "logprob": -8.25,
"text": " " "text": " "
} }
], ],
@ -40,60 +40,60 @@
}, },
{ {
"id": 395, "id": 395,
"logprob": -0.46411133, "logprob": -0.3125,
"special": false, "special": false,
"text": "and" "text": "and"
}, },
{ {
"id": 13735, "id": 4797,
"logprob": -2.1132812,
"special": false,
"text": " orange"
},
{
"id": 313,
"logprob": -1.2128906,
"special": false,
"text": " ("
},
{
"id": 249,
"logprob": -2.3671875,
"special": false,
"text": "in"
},
{
"id": 253,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " the" "text": " blue"
}, },
{ {
"id": 1340, "id": 9830,
"logprob": -1.640625, "logprob": -1.65625,
"special": false, "special": false,
"text": " order" "text": " colors"
}, },
{ {
"id": 597, "id": 15,
"logprob": -0.5488281,
"special": false,
"text": " they"
},
{
"id": 3176,
"logprob": -0.48608398,
"special": false,
"text": " appear"
},
{
"id": 275,
"logprob": 0.0, "logprob": 0.0,
"special": false, "special": false,
"text": " in" "text": "."
},
{
"id": 329,
"logprob": -2.4375,
"special": false,
"text": " A"
},
{
"id": 1180,
"logprob": -1.953125,
"special": false,
"text": " number"
},
{
"id": 273,
"logprob": 0.0,
"special": false,
"text": " of"
},
{
"id": 1027,
"logprob": -1.5546875,
"special": false,
"text": " different"
},
{
"id": 3295,
"logprob": -0.97265625,
"special": false,
"text": " color"
} }
], ],
"top_tokens": null "top_tokens": null
}, },
"generated_text": "blue, red, yellow, \nand orange (in the order they appear in" "generated_text": "blue, red, yellow, \nand blue colors. A number of different color"
} }

View File

@ -12,22 +12,22 @@
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.8125, "logprob": -0.83984375,
"text": " is" "text": " is"
}, },
{ {
"id": 18147, "id": 18147,
"logprob": -12.828125, "logprob": -12.8125,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 20727, "id": 20727,
"logprob": -3.0, "logprob": -2.84375,
"text": " Learning" "text": " Learning"
}, },
{ {
"id": 32, "id": 32,
"logprob": -1.1484375, "logprob": -1.25,
"text": "?" "text": "?"
} }
], ],
@ -35,61 +35,61 @@
"tokens": [ "tokens": [
{ {
"id": 187, "id": 187,
"logprob": -0.3552246, "logprob": -0.37890625,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 187, "id": 187,
"logprob": -0.38378906, "logprob": -0.4296875,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 30763, "id": 30763,
"logprob": -1.1279297, "logprob": -1.078125,
"special": false, "special": false,
"text": "Deep" "text": "Deep"
}, },
{ {
"id": 4715, "id": 4715,
"logprob": -0.5595703, "logprob": -0.515625,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.60253906, "logprob": -0.6015625,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 247, "id": 247,
"logprob": -0.7050781, "logprob": -0.65625,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 747, "id": 747,
"logprob": -2.0488281, "logprob": -2.109375,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 1511, "id": 1511,
"logprob": -2.3808594, "logprob": -2.328125,
"special": false, "special": false,
"text": " type" "text": " type"
}, },
{ {
"id": 273, "id": 273,
"logprob": -0.0026416779, "logprob": -0.0032653809,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5145, "id": 5145,
"logprob": -1.2851562, "logprob": -1.28125,
"special": false, "special": false,
"text": " machine" "text": " machine"
} }
@ -111,22 +111,22 @@
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.78027344, "logprob": -0.80078125,
"text": " is" "text": " is"
}, },
{ {
"id": 18147, "id": 18147,
"logprob": -12.8203125, "logprob": -13.25,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 20727, "id": 20727,
"logprob": -2.9902344, "logprob": -2.828125,
"text": " Learning" "text": " Learning"
}, },
{ {
"id": 32, "id": 32,
"logprob": -1.1523438, "logprob": -1.1953125,
"text": "?" "text": "?"
} }
], ],
@ -134,61 +134,61 @@
"tokens": [ "tokens": [
{ {
"id": 187, "id": 187,
"logprob": -0.35351562, "logprob": -0.296875,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 187, "id": 187,
"logprob": -0.38256836, "logprob": -0.3359375,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 30763, "id": 30763,
"logprob": -1.1269531, "logprob": -1.2578125,
"special": false, "special": false,
"text": "Deep" "text": "Deep"
}, },
{ {
"id": 4715, "id": 4715,
"logprob": -0.54541016, "logprob": -0.5546875,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.59765625, "logprob": -0.62890625,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 247, "id": 247,
"logprob": -0.7001953, "logprob": -0.64453125,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 747, "id": 747,
"logprob": -2.0585938, "logprob": -2.078125,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 1511, "id": 1511,
"logprob": -2.3789062, "logprob": -2.28125,
"special": false, "special": false,
"text": " type" "text": " type"
}, },
{ {
"id": 273, "id": 273,
"logprob": -0.0027446747, "logprob": -0.0030670166,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5145, "id": 5145,
"logprob": -1.2851562, "logprob": -1.3125,
"special": false, "special": false,
"text": " machine" "text": " machine"
} }
@ -210,22 +210,22 @@
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.78027344, "logprob": -0.80078125,
"text": " is" "text": " is"
}, },
{ {
"id": 18147, "id": 18147,
"logprob": -12.8203125, "logprob": -13.25,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 20727, "id": 20727,
"logprob": -2.9902344, "logprob": -2.828125,
"text": " Learning" "text": " Learning"
}, },
{ {
"id": 32, "id": 32,
"logprob": -1.1523438, "logprob": -1.1953125,
"text": "?" "text": "?"
} }
], ],
@ -233,61 +233,61 @@
"tokens": [ "tokens": [
{ {
"id": 187, "id": 187,
"logprob": -0.35351562, "logprob": -0.296875,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 187, "id": 187,
"logprob": -0.38256836, "logprob": -0.3359375,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 30763, "id": 30763,
"logprob": -1.1269531, "logprob": -1.2578125,
"special": false, "special": false,
"text": "Deep" "text": "Deep"
}, },
{ {
"id": 4715, "id": 4715,
"logprob": -0.54541016, "logprob": -0.5546875,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.59765625, "logprob": -0.62890625,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 247, "id": 247,
"logprob": -0.7001953, "logprob": -0.64453125,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 747, "id": 747,
"logprob": -2.0585938, "logprob": -2.078125,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 1511, "id": 1511,
"logprob": -2.3789062, "logprob": -2.28125,
"special": false, "special": false,
"text": " type" "text": " type"
}, },
{ {
"id": 273, "id": 273,
"logprob": -0.0027446747, "logprob": -0.0030670166,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5145, "id": 5145,
"logprob": -1.2851562, "logprob": -1.3125,
"special": false, "special": false,
"text": " machine" "text": " machine"
} }
@ -309,22 +309,22 @@
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.78027344, "logprob": -0.80078125,
"text": " is" "text": " is"
}, },
{ {
"id": 18147, "id": 18147,
"logprob": -12.8203125, "logprob": -13.25,
"text": " Deep" "text": " Deep"
}, },
{ {
"id": 20727, "id": 20727,
"logprob": -2.9902344, "logprob": -2.828125,
"text": " Learning" "text": " Learning"
}, },
{ {
"id": 32, "id": 32,
"logprob": -1.1523438, "logprob": -1.1953125,
"text": "?" "text": "?"
} }
], ],
@ -332,61 +332,61 @@
"tokens": [ "tokens": [
{ {
"id": 187, "id": 187,
"logprob": -0.35351562, "logprob": -0.296875,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 187, "id": 187,
"logprob": -0.38256836, "logprob": -0.3359375,
"special": false, "special": false,
"text": "\n" "text": "\n"
}, },
{ {
"id": 30763, "id": 30763,
"logprob": -1.1269531, "logprob": -1.2578125,
"special": false, "special": false,
"text": "Deep" "text": "Deep"
}, },
{ {
"id": 4715, "id": 4715,
"logprob": -0.54541016, "logprob": -0.5546875,
"special": false, "special": false,
"text": " learning" "text": " learning"
}, },
{ {
"id": 310, "id": 310,
"logprob": -0.59765625, "logprob": -0.62890625,
"special": false, "special": false,
"text": " is" "text": " is"
}, },
{ {
"id": 247, "id": 247,
"logprob": -0.7001953, "logprob": -0.64453125,
"special": false, "special": false,
"text": " a" "text": " a"
}, },
{ {
"id": 747, "id": 747,
"logprob": -2.0585938, "logprob": -2.078125,
"special": false, "special": false,
"text": " new" "text": " new"
}, },
{ {
"id": 1511, "id": 1511,
"logprob": -2.3789062, "logprob": -2.28125,
"special": false, "special": false,
"text": " type" "text": " type"
}, },
{ {
"id": 273, "id": 273,
"logprob": -0.0027446747, "logprob": -0.0030670166,
"special": false, "special": false,
"text": " of" "text": " of"
}, },
{ {
"id": 5145, "id": 5145,
"logprob": -1.2851562, "logprob": -1.3125,
"special": false, "special": false,
"text": " machine" "text": " machine"
} }

View File

@ -47,14 +47,14 @@ async def test_mamba_all_params(fused_kernel_mamba, response_snapshot):
assert response.details.generated_tokens == 10 assert response.details.generated_tokens == 10
assert ( assert (
response.generated_text response.generated_text
== "blue, red, yellow, \nand orange (in the order they appear in" == "blue, red, yellow, \nand blue colors. A number of different color"
) )
assert response == response_snapshot assert response == response_snapshot
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.private @pytest.mark.private
async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): async def test_mamba_load(fused_kernel_mamba, generate_load, generous_response_snapshot):
responses = await generate_load( responses = await generate_load(
fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4 fused_kernel_mamba, "What is Deep Learning?", max_new_tokens=10, n=4
) )
@ -63,4 +63,4 @@ async def test_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
assert all([r.generated_text == responses[0].generated_text for r in responses]) assert all([r.generated_text == responses[0].generated_text for r in responses])
assert responses[0].generated_text == "\n\nDeep learning is a new type of machine" assert responses[0].generated_text == "\n\nDeep learning is a new type of machine"
assert responses == response_snapshot assert responses == generous_response_snapshot

View File

@ -3,7 +3,6 @@ import torch.distributed
from mamba_ssm.ops.triton.selective_state_update import selective_state_update from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.utils.generation import InferenceParams
from torch import nn from torch import nn
from typing import Optional, Tuple, Any from typing import Optional, Tuple, Any
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
@ -18,6 +17,17 @@ from text_generation_server.utils.layers import (
from einops import rearrange from einops import rearrange
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
import math import math
from dataclasses import dataclass
@dataclass
class InferenceParams:
"""Inference parameters that are passed to the main model in order
to efficienly calculate and store the context during inference."""
max_seqlen: int
max_batch_size: int
conv_states: torch.Tensor
ssm_states: torch.Tensor
seqlen_offset: int
class MambaConfig(PretrainedConfig): class MambaConfig(PretrainedConfig):
@ -56,9 +66,9 @@ class MambaConfig(PretrainedConfig):
class MambaBlock(nn.Module): class MambaBlock(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, layer_id):
super().__init__() super().__init__()
self.layer_idx = int(prefix.split(".")[2]) self.layer_id = layer_id
self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False) self.in_proj = FastLinear.load(config, f"{prefix}.in_proj", weights, bias=False)
self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False) self.x_proj = FastLinear.load(config, f"{prefix}.x_proj", weights, bias=False)
self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True) self.dt_proj = FastLinear.load(config, f"{prefix}.dt_proj", weights, bias=True)
@ -79,21 +89,20 @@ class MambaBlock(nn.Module):
# inference_params # inference_params
def forward(self, hidden_states: torch.Tensor, inference_params=None): def forward(self, hidden_states: torch.Tensor, inference_params=None):
_, seqlen, _ = hidden_states.shape
conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
if inference_params.seqlen_offset > 0: if inference_params.seqlen_offset > 0:
conv_state = inference_params.conv_states[self.layer_id]
ssm_state = inference_params.ssm_states[self.layer_id]
out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state) out, conv_state, ssm_state = self.step(hidden_states, conv_state, ssm_state)
return out, conv_state, ssm_state return out, conv_state, ssm_state
_, seqlen, _ = hidden_states.shape
projected_states = self.in_proj(hidden_states).transpose(1, 2) projected_states = self.in_proj(hidden_states).transpose(1, 2)
# assert projected_states.shape == [batch_size, 2 * dstate, seqlen], f"{projected_states.shape} [{batch_size}, {dstate}, {seqlen}]"
x, z = projected_states.chunk(2, dim=1) x, z = projected_states.chunk(2, dim=1)
conv_state = F.pad(x, (self.d_conv - seqlen, 0)) conv_state = F.pad(x, (self.d_conv - seqlen, 0))
x = causal_conv1d_fn( x = causal_conv1d_fn(
x=x, x=x,
weight=self.conv1d.weight.view( weight=self.conv1d.weight.squeeze(1),
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
),
bias=self.conv1d.bias, bias=self.conv1d.bias,
activation=self.activation, activation=self.activation,
) )
@ -126,56 +135,28 @@ class MambaBlock(nn.Module):
return attn_outputs, conv_state, last_state return attn_outputs, conv_state, last_state
def step(self, hidden_states, conv_state, ssm_state): def step(self, hidden_states, conv_state, ssm_state):
_xz = self.in_proj(hidden_states) xz = self.in_proj(hidden_states.squeeze(1))
_x, _z = _xz.chunk(2, dim=-1) # (B D) x, z = xz.chunk(2, dim=-1) # (B D)
conv_state_new = torch.cat([conv_state, _x.transpose(1, 2)], dim=-1) x = causal_conv1d_update(x, conv_state, self.conv1d.weight.squeeze(1), self.conv1d.bias, self.activation)
conv_out = causal_conv1d_fn( x_db = self.x_proj(x) # (B dt_rank+2*d_state)
x=conv_state_new, dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
weight=self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
),
bias=self.conv1d.bias,
activation=self.activation,
)
conv_state = conv_state_new[:, :, 1:]
bsz, seqlen, dim = hidden_states.shape
output_tensor = torch.zeros(
(bsz, seqlen, dim), device=hidden_states.device, dtype=hidden_states.dtype
)
for i in range(0, bsz):
x = conv_out[i : i + 1, :, -1]
z = _z[i : i + 1, -1, :]
x_db = self.x_proj(x)
dt, B, C = torch.split(
x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1
)
dt = F.linear(dt, self.dt_proj.weight) dt = F.linear(dt, self.dt_proj.weight)
A = self.negA
y = selective_state_update( y = selective_state_update(
ssm_state[i : i + 1, :, :], ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
x,
dt,
self.negA,
B,
C,
self.D,
z=z,
dt_bias=self.dt_proj.bias,
dt_softplus=True,
) )
out = self.out_proj(y) out = self.out_proj(y)
output_tensor[i] = out return out.unsqueeze(1), conv_state.clone(), ssm_state.clone()
return output_tensor, conv_state, ssm_state
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, layer_id, config, weights): def __init__(self, prefix, config, weights, layer_id):
super().__init__() super().__init__()
self.mamba_block = MambaBlock( self.mamba_block = MambaBlock(
prefix=f"{layer_id}.mixer", config=config, weights=weights prefix=f"{prefix}.mixer", config=config, weights=weights, layer_id=layer_id
) )
self.layer_norm = FastRMSNorm.load( self.layer_norm = FastRMSNorm.load(
prefix=f"{layer_id}.norm", weights=weights, eps=config.layer_norm_epsilon prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_epsilon
) )
def forward( def forward(
@ -200,7 +181,7 @@ class MambaModel(nn.Module):
self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights) self.embed_tokens = TensorParallelEmbedding(f"{prefix}.embedding", weights)
self.blocks = nn.ModuleList( self.blocks = nn.ModuleList(
[ [
ResidualBlock(f"{prefix}.layers.{i}", config, weights) ResidualBlock(f"{prefix}.layers.{i}", config, weights, layer_id=i)
for i in range(config.n_layer) for i in range(config.n_layer)
] ]
) )
@ -216,14 +197,12 @@ class MambaModel(nn.Module):
self, input_ids: torch.Tensor, inference_params=None, residual=None self, input_ids: torch.Tensor, inference_params=None, residual=None
) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]: ) -> Tuple[torch.Tensor, torch.Tensor, InferenceParams]:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
for block in self.blocks: for i, block in enumerate(self.blocks):
hidden_states, residual, conv_state, ssm_state = block( hidden_states, residual, conv_state, ssm_state = block(
hidden_states, residual, inference_params hidden_states, residual, inference_params
) )
inference_params.key_value_memory_dict[block.mamba_block.layer_idx] = ( inference_params.conv_states[i].copy_(conv_state)
conv_state, inference_params.ssm_states[i].copy_(ssm_state)
ssm_state,
)
hidden_states = ( hidden_states = (
hidden_states + residual if residual is not None else hidden_states hidden_states + residual if residual is not None else hidden_states
@ -234,4 +213,4 @@ class MambaModel(nn.Module):
# update the offset for the next inference using these params # update the offset for the next inference using these params
inference_params.seqlen_offset += input_ids.size(1) inference_params.seqlen_offset += input_ids.size(1)
return logits, input_ids, inference_params return logits

View File

@ -28,12 +28,12 @@ from text_generation_server.models.cache_manager import (
BLOCK_SIZE, BLOCK_SIZE,
) )
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.models.globals import MEM_POOL
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
from text_generation_server.utils.dist import MEMORY_FRACTION from text_generation_server.utils.dist import MEMORY_FRACTION
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
MEM_POOL = torch.cuda.graph_pool_handle()
@dataclass @dataclass

View File

@ -0,0 +1,3 @@
import torch
MEM_POOL = torch.cuda.graph_pool_handle()

View File

@ -2,17 +2,20 @@ import torch
import torch.distributed import torch.distributed
from transformers import AutoTokenizer, PreTrainedTokenizerBase from transformers import AutoTokenizer, PreTrainedTokenizerBase
from typing import Optional from typing import Optional
import os
from text_generation_server.models.custom_modeling.mamba_modeling import ( from text_generation_server.models.custom_modeling.mamba_modeling import (
MambaConfig, MambaConfig,
) )
from loguru import logger
from text_generation_server.pb import generate_pb2 from text_generation_server.pb import generate_pb2
from text_generation_server.utils import ( from text_generation_server.utils import (
initialize_torch_distributed, initialize_torch_distributed,
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.models.globals import MEM_POOL
import time import time
from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel from text_generation_server.models.custom_modeling.mamba_modeling import MambaModel, InferenceParams
from text_generation_server.models import Model from text_generation_server.models import Model
from typing import Any, List, Optional, Tuple, Type, Dict from typing import Any, List, Optional, Tuple, Type, Dict
from text_generation_server.models.types import ( from text_generation_server.models.types import (
@ -24,7 +27,34 @@ from text_generation_server.models.types import (
from text_generation_server.utils.tokens import batch_top_tokens, Sampling from text_generation_server.utils.tokens import batch_top_tokens, Sampling
from dataclasses import dataclass from dataclasses import dataclass
from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling from text_generation_server.utils import NextTokenChooser, StoppingCriteria, Sampling
from mamba_ssm.utils.generation import InferenceParams
def new_inference_params(n_blocks: int, batch_size: int, d_inner: int, d_conv: int, d_state: int, seqlen_offset: int, dtype: torch.dtype, device: torch.device):
max_seqlen = 0
conv_states = torch.zeros(
(n_blocks,
batch_size,
d_inner,
d_conv,),
device=device,
dtype=dtype,
)
ssm_states = torch.zeros(
(n_blocks,
batch_size,
d_inner,
d_state,),
device=device,
dtype=dtype,
)
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset,
conv_states=conv_states,
ssm_states=ssm_states,
)
return inference_params
@dataclass @dataclass
@ -221,14 +251,8 @@ class MambaBatch(Batch):
# TODO # TODO
# Kept it simple by just updating the state, maybe updating the other CPU values is necessary. # Kept it simple by just updating the state, maybe updating the other CPU values is necessary.
key_value_memory_dict = {} self.inference_params.conv_states = self.inference_params.conv_states[:, indices]
for i, ( self.inference_params.ssm_states = self.inference_params.ssm_states[:, indices]
conv_state,
ssm_state,
) in self.inference_params.key_value_memory_dict.items():
key_value_memory_dict[i] = (conv_state[indices], ssm_state[indices])
self.inference_params.key_value_memory_dict = key_value_memory_dict
return self return self
@classmethod @classmethod
@ -254,9 +278,16 @@ class MambaBatch(Batch):
top_n_tokens = [] top_n_tokens = []
max_tokens = 0 max_tokens = 0
max_seqlen = 0 max_seqlen = 0
batch_size = 0
seqlen_offset = 0 seqlen_offset = 0
(n_blocks, _, d_inner, d_conv) = (
batches[0].inference_params.conv_states.shape
)
(_, _, _, d_state) = batches[0].inference_params.ssm_states.shape
dtype = batches[0].inference_params.conv_states.dtype
device = batches[0].inference_params.conv_states.device
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=total_batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=device, dtype=dtype)
# Batch tensors # Batch tensors
input_ids = None input_ids = None
top_n_tokens_tensor = None top_n_tokens_tensor = None
@ -303,63 +334,16 @@ class MambaBatch(Batch):
max_input_length - batch.max_input_length max_input_length - batch.max_input_length
) * len(batch) ) * len(batch)
max_seqlen = max(max_seqlen, batch.inference_params.max_seqlen) inference_params.max_seqlen = max(inference_params.max_seqlen, batch.inference_params.max_seqlen)
seqlen_offset = max(seqlen_offset, batch.inference_params.seqlen_offset) assert batch.inference_params.seqlen_offset != 0, "Invalid seqlen offset"
batch_size += batch.inference_params.max_batch_size inference_params.seqlen_offset = max(inference_params.seqlen_offset, batch.inference_params.seqlen_offset)
inference_params.conv_states[:, start_index:end_index] = batch.inference_params.conv_states
inference_params.ssm_states[:, start_index:end_index] = batch.inference_params.ssm_states
start_index = end_index start_index = end_index
(_, d_model, d_conv) = (
batches[0].inference_params.key_value_memory_dict[0][0].shape
)
(_, _, d_state) = batches[0].inference_params.key_value_memory_dict[0][1].shape
n_blocks = len(batches[0].inference_params.key_value_memory_dict)
dtype = batches[0].inference_params.key_value_memory_dict[0][0].dtype
device = batches[0].inference_params.key_value_memory_dict[0][0].device
key_value_memory_dict = {}
for i in range(n_blocks):
conv_state = torch.zeros(
batch_size,
d_model,
d_conv,
device=device,
dtype=dtype,
)
ssm_state = torch.zeros(
batch_size,
d_model,
d_state,
device=device,
dtype=dtype,
)
key_value_memory_dict[i] = (conv_state, ssm_state)
lengths_per_sample = torch.zeros(batch_size, dtype=torch.int32, device=device)
inference_params = InferenceParams(
max_seqlen=max_seqlen,
max_batch_size=batch_size,
seqlen_offset=seqlen_offset,
key_value_memory_dict=key_value_memory_dict,
lengths_per_sample=lengths_per_sample,
)
current_batch = 0
for batch in batches:
for i in range(n_blocks):
conv_state, ssm_state = batch.inference_params.key_value_memory_dict[i]
batch_size = batch.inference_params.max_batch_size
inference_params.key_value_memory_dict[i][0][
current_batch : current_batch + batch_size
] = conv_state
inference_params.key_value_memory_dict[i][1][
current_batch : current_batch + batch_size
] = ssm_state
inference_params.lengths_per_sample[
current_batch : current_batch + batch_size
] = batch.inference_params.lengths_per_sample
current_batch += batch_size
return cls( return cls(
batch_id=batches[0].batch_id, batch_id=batches[0].batch_id,
requests=requests, requests=requests,
@ -394,9 +378,13 @@ class Mamba(Model):
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
self.process_group, _rank, _world_size = initialize_torch_distributed() self.process_group, _rank, _world_size = initialize_torch_distributed()
self.cuda_graphs = {}
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device("cuda") device = torch.device("cuda")
dtype = torch.float16 if dtype is None else dtype # Bf16 is important. In f16 accumulations in the matmul are causing
# differences while the server is under load.
# This is detectable by the integration load test
dtype = torch.bfloat16 if dtype is None else dtype
else: else:
if quantize: if quantize:
raise ValueError("quantization is not available on CPU") raise ValueError("quantization is not available on CPU")
@ -439,74 +427,120 @@ class Mamba(Model):
def warmup(self, batch) -> Optional[int]: def warmup(self, batch) -> Optional[int]:
# TODO: implement warmup for Mamba if needed # TODO: implement warmup for Mamba if needed
if os.getenv("ENABLE_CUDA_GRAPHS", "False") == "True":
if self.speculate is None or self.speculate == 0:
try:
logger.info("Experimental support for Cuda Graphs is enabled")
# Warmup cuda graphs
for bs in [1, 2, 4] + [8 * i for i in range(1, 9)]:
self.cuda_graph_warmup(bs)
except Exception:
logger.exception(f"Decode cuda graph warmup failed")
return None return None
def cuda_graph_warmup(self, batch_size: int):
input_ids = torch.zeros((batch_size, 1), dtype=torch.int64, device=self.device)
n_blocks = len(self.model.blocks)
d_state = self.model.config.d_state
d_conv = self.model.config.d_conv
# Inner takes the expand multiplication
d_inner = self.model.config.d_inner
# Important seqlen_offset to go through the update mecanism with the state
seqlen_offset = 1
inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
graph = torch.cuda.CUDAGraph()
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
input_ids=input_ids,
inference_params=inference_params
)
torch.cuda.synchronize()
with torch.cuda.graph(graph, pool=MEM_POOL):
logits = self.model.forward(
input_ids=input_ids,
inference_params=inference_params
)
torch.cuda.synchronize()
graph_dict = {
"input_ids": input_ids,
"inference_params": inference_params,
"graph": graph,
"logits": logits
}
self.cuda_graphs[batch_size] = graph_dict
def forward( def forward(
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
past: Optional[List[torch.Tensor]] = None, inference_params: Any
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
bs = input_ids.shape[0]
padded_bs = bs
if bs == 3:
padded_bs = 4
elif 3 < bs <= 8:
padded_bs = 8
elif bs > 8:
padded_bs = (bs + 7) // 8 * 8
# Try to find an associated cuda graph
cuda_graph = self.cuda_graphs.get(padded_bs, None)
is_prefill = inference_params is None or inference_params.seqlen_offset == 0
if is_prefill or cuda_graph is None:
return self.model( return self.model(
input_ids, input_ids,
past=past, inference_params=inference_params,
) )
# Copy inputs to the static inputs of the cuda graph
# Static inputs are potentially padded
cuda_graph["input_ids"][: bs] = input_ids
cuda_graph["inference_params"].conv_states[:, : bs] = inference_params.conv_states
cuda_graph["inference_params"].ssm_states[:, : bs] = inference_params.ssm_states
# Replay the graph
cuda_graph["graph"].replay()
inference_params.conv_states.copy_(cuda_graph["inference_params"].conv_states[:, :bs])
inference_params.ssm_states.copy_(cuda_graph["inference_params"].ssm_states[:, :bs])
# Slice output to the correct shape
return cuda_graph["logits"][:bs]
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns() start = time.time_ns()
input_ids = ( input_ids = (
batch.input_ids batch.input_ids
) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids ) # batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
batch_size = input_ids.shape[0] batch_size, max_seqlen = input_ids.shape
max_seqlen = input_ids.shape[1]
dtype = input_ids.dtype
# Inference params # Inference params
seqlen_og = 0
inf_cache = {}
lengths_per_sample = (
torch.ones(batch_size, dtype=torch.int32, device=input_ids.device)
* max_seqlen
)
if batch.inference_params is None: if batch.inference_params is None:
inference_params = InferenceParams( # 0 is important here
max_seqlen=max_seqlen, seqlen_offset = 0
max_batch_size=batch_size, n_blocks = len(self.model.blocks)
seqlen_offset=seqlen_og, d_state = self.model.config.d_state
key_value_memory_dict=inf_cache, d_conv = self.model.config.d_conv
lengths_per_sample=lengths_per_sample, d_inner = self.model.config.d_inner
) inference_params = new_inference_params(n_blocks=n_blocks, batch_size=batch_size, d_state=d_state, d_conv=d_conv, d_inner=d_inner, seqlen_offset=seqlen_offset, device=self.device, dtype=self.dtype)
# Allocate inference cache
for res_block in self.model.blocks:
block = res_block.mamba_block
conv_state = torch.zeros(
batch_size,
self.model.config.d_model * self.model.config.expand,
self.model.config.d_conv,
device=block.conv1d.weight.device,
dtype=block.conv1d.weight.dtype,
)
ssm_state = torch.zeros(
batch_size,
self.model.config.d_model * self.model.config.expand,
self.model.config.d_state,
device=block.dt_proj.weight.device,
dtype=block.dt_proj.weight.dtype,
)
inference_params.key_value_memory_dict[block.layer_idx] = (
conv_state,
ssm_state,
)
batch.inference_params = inference_params batch.inference_params = inference_params
# Forward pass # Forward pass
logits, past_input_ids, new_inference_params = self.model( logits = self.forward(
input_ids, batch.inference_params input_ids, inference_params=batch.inference_params
) )
batch.inference_params = new_inference_params
# batch.inference_params = new_inference_params
# Results # Results
generations: List[Generation] = [] generations: List[Generation] = []
stopped = True stopped = True