Improvements within mamba.
parent 5b6f9259c1
commit a3c45da0a4
@@ -51,12 +51,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapsh
 # TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
 # 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
 # NOTE: the fast layer norm has strict requirements on the input shape
-# @pytest.mark.asyncio
-# @pytest.mark.private
-# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-#     responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
+    responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
 
-#     assert len(responses) == 4
-#     assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])
 
-#     assert responses == response_snapshot
+    assert responses == response_snapshot
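Note on the re-enabled load test: `squeeze(0)` only removes the leading dimension when it equals 1, so with several concurrent requests the hidden states stayed 3-D and the fused layer norm's `x0.dim() == 2` check tripped. A rough PyTorch sketch of the shape issue (tensor sizes invented for illustration):

```python
import torch

# Illustrative shapes only: 4 concurrent requests, as in the load test above.
hidden_states = torch.randn(4, 7, 16)   # [batch, seq, hidden]

# squeeze(0) is a no-op when batch > 1, so a kernel asserting a 2-D input fails.
print(hidden_states.squeeze(0).dim())                  # 3 -> "Expected x0.dim() == 2" error
print(torch.randn(1, 7, 16).squeeze(0).dim())          # 2 -> only the single-request case worked

# Flattening batch and sequence into one token dimension satisfies the 2-D requirement.
print(hidden_states.view(-1, hidden_states.shape[-1]).dim())  # 2
```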
@@ -311,8 +311,8 @@ class CausalLMBatch(Batch):
             end_index = start_index + len(batch)
 
             # We only concatenate batches that did at least one step
-            if batch.past_key_values is None:
-                raise ValueError("only concatenate prefilled batches")
+            # if batch.past_key_values is None:
+            #     raise ValueError("only concatenate prefilled batches")
 
             # Create empty tensor
             # input_ids is always of shape [batch_size, 1]
@@ -91,8 +91,9 @@ class ResidualBlock(nn.Module):
         hidden_states: torch.Tensor,
     ):
         residual = hidden_states
-        hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
-        hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
+        shape = hidden_states.shape
+        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
+        hidden_states = residual + self.mamba_block(hidden_states.view(*shape))
         return hidden_states
 
 class MambaModel(nn.Module):
@@ -114,5 +115,6 @@ class MambaModel(nn.Module):
         for block in self.blocks:
             hidden_states = block(hidden_states)
 
-        final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
-        return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
+        shape = hidden_states.shape
+        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
+        return self.lm_head(final_hidden_states.view(*shape)), input_ids
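Both `ResidualBlock.forward` and `MambaModel.forward` now share the same flatten-normalize-restore pattern. A minimal sketch of that pattern, with `nn.LayerNorm` standing in for the fused layer norm (the real kernel also returns a residual, hence the `hidden_states, _ =` unpacking above) and a plain `nn.Linear` as a placeholder for the Mamba block:

```python
import torch
import torch.nn as nn

class ResidualBlockSketch(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(hidden_size)              # stand-in for the fused kernel
        self.mamba_block = nn.Linear(hidden_size, hidden_size)   # placeholder for the Mamba block

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        # Flatten [batch, seq, hidden] -> [batch * seq, hidden] so the norm sees a
        # 2-D input regardless of batch size, then restore the original shape.
        shape = hidden_states.shape
        hidden_states = self.layer_norm(hidden_states.view(-1, shape[-1]))
        hidden_states = residual + self.mamba_block(hidden_states.view(*shape))
        return hidden_states

print(ResidualBlockSketch(16)(torch.randn(4, 7, 16)).shape)  # torch.Size([4, 7, 16])
```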
@@ -30,23 +30,9 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling
 
 
 class MambaCausalLMBatch(CausalLMBatch):
-    past_input_ids: Optional[torch.Tensor]
+    pass
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.past_input_ids = None
 
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
-        batch.keys_head_dim_last = False
-        return batch
 
 
 class Mamba(Model):
@@ -119,7 +105,7 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()
 
-        input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+        input_ids = batch.input_ids
 
         logits, past_input_ids = self.model(input_ids)[:2]
 
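Since `MambaModel.forward` returns `(self.lm_head(...), input_ids)`, the `past_input_ids` unpacked here are simply the token ids that were fed in; no recurrent state is cached yet, so every decode step reruns the full sequence with the newly sampled tokens appended. A toy version of that loop (model and shapes are invented for illustration):

```python
import torch

class ToyModel(torch.nn.Module):
    """Mimics the (logits, input_ids) return signature used in the diff."""
    def __init__(self, vocab: int = 32, hidden: int = 8):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def forward(self, input_ids: torch.Tensor):
        return self.lm_head(self.emb(input_ids)), input_ids

def greedy_decode(model, input_ids: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    for _ in range(max_new_tokens):
        logits, past_input_ids = model(input_ids)[:2]
        next_token_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch, 1]
        input_ids = torch.cat([past_input_ids, next_token_id], dim=1)  # full sequence rerun next step
    return input_ids

print(greedy_decode(ToyModel(), torch.tensor([[1, 2, 3]]), max_new_tokens=5).shape)  # torch.Size([1, 8])
```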
@@ -151,6 +137,8 @@ class Mamba(Model):
         )
 
         # For each member of the batch
+        next_token_ids = []
+        kept_batch_ids = []
         for i, (
             request,
             input_length,
@@ -235,7 +223,8 @@ class Mamba(Model):
                 )
             else:
                 prefill_tokens = None
-            past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
+            next_token_ids.append(next_token_id)
+            kept_batch_ids.append(i)
 
 
             if top_n_tokens > 0:
@@ -278,13 +267,18 @@ class Mamba(Model):
             batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
+        # Merge all new tokens
+        if next_token_ids:
+            next_token_ids = torch.cat(next_token_ids, dim=0)
+            past_input_ids = torch.cat([past_input_ids[kept_batch_ids], next_token_ids], dim=1)
+
         # We finished all generations in the batch; there is no next batch
         if stopped:
             forward_ns = start_decode - start
             decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)
 
-        batch.past_input_ids = past_input_ids
+        batch.input_ids = input_ids
 
         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode
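The last two hunks replace the per-request `torch.cat` inside the loop with a single merge after it: each still-active request contributes its sampled token (shape [1, 1]) to `next_token_ids` and its row index to `kept_batch_ids`, so finished requests simply drop out of the merged tensor. A small sketch of the shape bookkeeping (toy ids, not real model output):

```python
import torch

past_input_ids = torch.arange(12).reshape(4, 3)           # 4 requests, 3 tokens each
next_token_ids = [torch.tensor([[100]]), torch.tensor([[101]]), torch.tensor([[102]])]
kept_batch_ids = [0, 2, 3]                                 # request 1 finished this step

# One concatenation after the loop instead of one per request inside it.
next_token_ids = torch.cat(next_token_ids, dim=0)          # [3, 1]
past_input_ids = torch.cat([past_input_ids[kept_batch_ids], next_token_ids], dim=1)
print(past_input_ids)
# tensor([[  0,   1,   2, 100],
#         [  6,   7,   8, 101],
#         [  9,  10,  11, 102]])
```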