Improvements within mamba.
This commit is contained in:
parent 5b6f9259c1
commit a3c45da0a4
@@ -51,12 +51,12 @@ async def test_fused_kernel_mamba_all_params(fused_kernel_mamba, response_snapshot
 # TODO: fix `Expected x0.dim() == 2 to be true, but got false.`
 # 94: `hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))`
 # NOTE: the fast layer norm has strict requirements on the input shape
-# @pytest.mark.asyncio
-# @pytest.mark.private
-# async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
-# responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)
+@pytest.mark.asyncio
+@pytest.mark.private
+async def test_fused_kernel_mamba_load(fused_kernel_mamba, generate_load, response_snapshot):
+    responses = await generate_load(fused_kernel_mamba, "Test request", max_new_tokens=10, n=4)

-# assert len(responses) == 4
-# assert all([r.generated_text == responses[0].generated_text for r in responses])
+    assert len(responses) == 4
+    assert all([r.generated_text == responses[0].generated_text for r in responses])

-# assert responses == response_snapshot
+    assert responses == response_snapshot

@@ -311,8 +311,8 @@ class CausalLMBatch(Batch):
             end_index = start_index + len(batch)

             # We only concatenate batches that did at least one step
-            if batch.past_key_values is None:
-                raise ValueError("only concatenate prefilled batches")
+            # if batch.past_key_values is None:
+            #     raise ValueError("only concatenate prefilled batches")

             # Create empty tensor
             # input_ids is always of shape [batch_size, 1]

@@ -91,8 +91,9 @@ class ResidualBlock(nn.Module):
         hidden_states: torch.Tensor,
     ):
         residual = hidden_states
-        hidden_states, _ = self.layer_norm(hidden_states.squeeze(0))
-        hidden_states = residual + self.mamba_block(hidden_states.unsqueeze(0))
+        shape = hidden_states.shape
+        hidden_states, _ = self.layer_norm(hidden_states.view(-1, shape[-1]))
+        hidden_states = residual + self.mamba_block(hidden_states.view(*shape))
         return hidden_states

 class MambaModel(nn.Module):

@@ -114,5 +115,6 @@ class MambaModel(nn.Module):
         for block in self.blocks:
             hidden_states = block(hidden_states)

-        final_hidden_states, _ = self.norm_f(hidden_states.squeeze(0))
-        return self.lm_head(final_hidden_states.unsqueeze(0)), input_ids
+        shape = hidden_states.shape
+        final_hidden_states, _ = self.norm_f(hidden_states.view(-1, shape[-1]))
+        return self.lm_head(final_hidden_states.view(*shape)), input_ids

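The two reshape hunks above apply the same fix: flatten the hidden states to two dimensions before the layer norm (the NOTE in the test file points out that the fast layer norm has strict requirements on the input shape) and restore the original shape afterwards. Below is a minimal sketch of that pattern, using a plain torch.nn.LayerNorm as a stand-in for the fused kernel (which returns a tuple and only accepts 2D input) and an illustrative hidden size; the mamba block itself is omitted.

import torch
import torch.nn as nn

# Illustrative sizes; the real hidden size comes from the model config.
hidden_size = 768
layer_norm = nn.LayerNorm(hidden_size)  # stand-in for the fused layer norm

hidden_states = torch.randn(1, 10, hidden_size)  # [batch, seq_len, hidden_size]
residual = hidden_states

# Flatten all leading dimensions into a single "tokens" dimension so the
# layer norm sees a 2D tensor of shape [batch * seq_len, hidden_size] ...
shape = hidden_states.shape
normed = layer_norm(hidden_states.view(-1, shape[-1]))

# ... then restore the original 3D shape before the residual add (the real
# code feeds the reshaped tensor through self.mamba_block at this point).
hidden_states = residual + normed.view(*shape)
assert hidden_states.shape == shape
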
@@ -30,23 +30,9 @@ from text_generation_server.utils.tokens import batch_top_tokens, Sampling


 class MambaCausalLMBatch(CausalLMBatch):
-    past_input_ids: Optional[torch.Tensor]
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.past_input_ids = None
-
-    @classmethod
-    def from_pb(
-        cls,
-        pb: generate_pb2.Batch,
-        tokenizer: PreTrainedTokenizerBase,
-        dtype: torch.dtype,
-        device: torch.device,
-    ) -> "CausalLMBatch":
-        batch = super().from_pb(pb=pb, tokenizer=tokenizer, dtype=dtype, device=device)
-        batch.keys_head_dim_last = False
-        return batch
+    pass


 class Mamba(Model):

@@ -119,7 +105,7 @@ class Mamba(Model):
     def generate_token(self, batch) -> Tuple[List[Any], Optional[Any], Tuple[int, int]]:
         start = time.time_ns()

-        input_ids = batch.past_input_ids if batch.past_input_ids is not None else batch.input_ids
+        input_ids = batch.input_ids

         logits, past_input_ids = self.model(input_ids)[:2]

@@ -151,6 +137,8 @@ class Mamba(Model):
         )

         # For each member of the batch
+        next_token_ids = []
+        kept_batch_ids = []
         for i, (
             request,
             input_length,

@@ -235,7 +223,8 @@ class Mamba(Model):
                 )
             else:
                 prefill_tokens = None
-            past_input_ids = torch.cat([past_input_ids, next_token_id], dim=1)
+            next_token_ids.append(next_token_id)
+            kept_batch_ids.append(i)


             if top_n_tokens > 0:

@@ -278,13 +267,18 @@ class Mamba(Model):
             batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)

+        # Merge all new tokens
+        if next_token_ids:
+            next_token_ids = torch.cat(next_token_ids, dim=0)
+            past_input_ids = torch.cat([past_input_ids[kept_batch_ids], next_token_ids], dim=1)
+
         # We finished all generations in the batch; there is no next batch
         if stopped:
             forward_ns = start_decode - start
             decode_ns = time.time_ns() - start_decode
             return generations, None, (forward_ns, decode_ns)

-        batch.past_input_ids = past_input_ids
+        batch.input_ids = input_ids

         forward_ns = start_decode - start
         decode_ns = time.time_ns() - start_decode

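The generate_token hunks above replace the per-request torch.cat inside the loop with a single merge after it: every request that is still running appends its next token id and its batch index, and the surviving rows of past_input_ids are extended by one column in one concatenation. Here is a toy, self-contained sketch of that scheme; the tensors and the "finished" middle request are made up for illustration.

import torch

# past_input_ids holds the token history of three requests: [batch_size, seq_len].
past_input_ids = torch.tensor([[1, 2, 3],
                               [4, 5, 6],
                               [7, 8, 9]])

next_token_ids = []
kept_batch_ids = []

# Pretend request 1 finished (None); the other two produced one new token each.
for i, next_token_id in enumerate([torch.tensor([[10]]), None, torch.tensor([[30]])]):
    if next_token_id is None:
        continue
    next_token_ids.append(next_token_id)
    kept_batch_ids.append(i)

# Merge all new tokens in a single concatenation over the kept rows.
if next_token_ids:
    next_token_ids = torch.cat(next_token_ids, dim=0)  # [num_kept, 1]
    past_input_ids = torch.cat(
        [past_input_ids[kept_batch_ids], next_token_ids], dim=1
    )

print(past_input_ids)
# tensor([[ 1,  2,  3, 10],
#         [ 7,  8,  9, 30]])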