Improvments within mamba.

This commit is contained in:
Nicolas Patry 2024-01-31 10:28:58 +00:00
parent 5b6f9259c1
commit a3c45da0a4
4 changed files with 27 additions and 31 deletions

View File

@ -51,12 +51,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
# TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
# 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
# NOTE: the fast layer norm has strict requirements on the input shape
# @pytest.mark.asyncio
# @pytest.mark.private
# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
@pytest.mark.asyncio
@pytest.mark.private
async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
# assert len(responses) == 4
# assert all([r.generated_text == responses[0].generated_text for r in responses])
assert len(responses) == 4
assert all([r.generated_text == responses[0].generated_text for r in responses])
# assert responses == response_snapshot
assert responses == response_snapshot

View File

@ -311,8 +311,8 @@ class CausalLMBatch(Batch):
end_index = start_index + len(batch)
# We only concatenate batches that did at least one step
if batch.past_key_values is None:
raise ValueError("only concatenate prefilled batches")
# if batch.past_key_values is None:
# raise ValueError("only concatenate prefilled batches")
# Create empty tensor
# input_ids is always of shape [batch_size, 1]

View File

@ -91,8 +91,9 @@ class ResidualBlock(nn.Module):
hidden_states: torch.Tensor,
):
residual = hidden_states
hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
shape = hidden_states.shape
hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
hidden_states = residual + self.mamba_block(hidden_states.view(*shape))
return hidden_states
class MambaModel(nn.Module):
@ -114,5 +115,6 @@ class MambaModel(nn.Module):
for block in self.blocks:
hidden_states = block(hidden_states)
final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
shape = hidden_states.shape
final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
return self.lm_head(final_hidden_states.view(*shape)), input_ids

View File

@ -30,23 +30,9 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
class MambaCausalLMBatch(CausalLMBatch):
past_input_ids: Optional[torch.Tensor]
pass
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.past_input_ids = None
@classmethod
def from_pb(
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
class Mamba(Model):
@ -119,7 +105,7 @@ class Mamba(Model):
def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
start = time.time_ns()
input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
input_ids = batch.input_ids
logits, past_input_ids = self.model(input_ids)[:2]
@ -151,6 +137,8 @@ class Mamba(Model):
)
# For each member of the batch
next_token_ids = []
kept_batch_ids = []
for i, (
request,
input_length,
@ -235,7 +223,8 @@ class Mamba(Model):
)
else:
prefill_tokens = None
past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
next_token_ids.append(next_token_id)
kept_batch_ids.append(i)
if top_n_tokens > 0:
@ -278,13 +267,18 @@ class Mamba(Model):
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# Merge all new tokens
if next_token_ids:
next_token_ids = torch.cat(next_token_ids, dim=0)
past_input_ids = torch.cat([past_input_ids[kept_batch_ids], next_token_ids], dim=1)
# We finished all generations in the batch; there is no next batch
if stopped:
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode
return generations, None, (forward_ns, decode_ns)
batch.past_input_ids = past_input_ids
batch.input_ids = input_ids
forward_ns = start_decode - start
decode_ns = time.time_ns() - start_decode