diff --git a/integration-tests/models/test_fused_kernel_mamba.py b/integration-tests/models/test_fused_kernel_mamba.py index 0a449332..5d6aa81f 100644 --- a/integration-tests/models/test_fused_kernel_mamba.py +++ b/integration-tests/models/test_fused_kernel_mamba.py @@ -51,12 +51,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh # TODO: fix `Expected x0.dim() == 2 to be true, but got false.` # 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))` # NOTE: the fast layer norm has strict requirements on the input shape -# @pytest.mark.asyncio -# @pytest.mark.private -# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): -# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4) +@pytest.mark.asyncio +@pytest.mark.private +async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot): + responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4) -# assert len(responses) == 4 -# assert all([r.generated_text == responses[0].generated_text for r in responses]) + assert len(responses) == 4 + assert all([r.generated_text == responses[0].generated_text for r in responses]) -# assert responses == response_snapshot + assert responses == response_snapshot diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index 7b10256c..39b3f49b 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -311,8 +311,8 @@ class CausalLMBatch(Batch): end_index = start_index + len(batch) # We only concatenate batches that did at least one step - if batch.past_key_values is None: - raise ValueError("only concatenate prefilled batches") + # if batch.past_key_values is None: + # raise ValueError("only concatenate prefilled batches") # Create empty tensor # input_ids is always of shape [batch_size, 1] diff --git a/server/text_generation_server/models/custom_modeling/mamba_modeling.py b/server/text_generation_server/models/custom_modeling/mamba_modeling.py index ca5f9765..42da669a 100644 --- a/server/text_generation_server/models/custom_modeling/mamba_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mamba_modeling.py @@ -91,8 +91,9 @@ class ResidualBlock(nn.Module): hidden_states: torch.Tensor, ): residual = hidden_states - hidden_states, _ = self.layer_norm(hidden_states.squeeze(0)) - hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0)) + shape = hidden_states.shape + hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1])) + hidden_states = residual + self.mamba_block(hidden_states.view(*shape)) return hidden_states class MambaModel(nn.Module): @@ -114,5 +115,6 @@ class MambaModel(nn.Module): for block in self.blocks: hidden_states = block(hidden_states) - final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0)) - return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids + shape = hidden_states.shape + final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1])) + return self.lm_head(final_hidden_states.view(*shape)), input_ids diff --git a/server/text_generation_server/models/mamba.py b/server/text_generation_server/models/mamba.py index 05a5b99e..13d77abd 100644 --- a/server/text_generation_server/models/mamba.py +++ b/server/text_generation_server/models/mamba.py @@ -30,23 +30,9 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling class MambaCausalLMBatch(CausalLMBatch): - past_input_ids: Optional[torch.Tensor] + pass - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.past_input_ids = None - @classmethod - def from_pb( - cls, - pb: generate_pb2.Batch, - tokenizer: PreTrainedTokenizerBase, - dtype: torch.dtype, - device: torch.device, - ) -> "CausalLMBatch": - batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device) - batch.keys_head_dim_last = False - return batch class Mamba(Model): @@ -119,7 +105,7 @@ class Mamba(Model): def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]: start = time.time_ns() - input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids + input_ids = batch.input_ids logits, past_input_ids = self.model(input_ids)[:2] @@ -151,6 +137,8 @@ class Mamba(Model): ) # For each member of the batch + next_token_ids = [] + kept_batch_ids = [] for i, ( request, input_length, @@ -235,7 +223,8 @@ class Mamba(Model): ) else: prefill_tokens = None - past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1) + next_token_ids.append(next_token_id) + kept_batch_ids.append(i) if top_n_tokens > 0: @@ -278,13 +267,18 @@ class Mamba(Model): batch.read_offsets[i] = read_offset batch.max_input_length = max(batch.max_input_length, new_input_length) + # Merge all new tokens + if next_token_ids: + next_token_ids = torch.cat(next_token_ids, dim=0) + past_input_ids = torch.cat([past_input_ids[kept_batch_ids], next_token_ids], dim=1) + # We finished all generations in the batch; there is no next batch if stopped: forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode return generations, None, (forward_ns, decode_ns) - batch.past_input_ids = past_input_ids + batch.input_ids = input_ids forward_ns = start_decode - start decode_ns = time.time_ns() - start_decode