Saving some VRAM. (#2790)
* Saving some VRAM.
  - 8B on 4xL4 with attention=flashdecoding: before, 4.28 GB of VRAM left; after, 4.32 GB left, so roughly 40 MB freed. (The buffer-reuse idea is sketched below, before the diff.)
  - The effect is not as visible with attention=flashinfer and n_shard=1; I suspect it's linked to the torch allocator.
* Adding an assertion.
parent 2003d8be0c
commit b57f370386
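For orientation, here is a minimal, self-contained sketch of the idea the diff implements: only the first (largest) batch size allocates fresh warmup tensors, and every smaller batch size reuses views into those tensors by slicing. This is an illustration, not the repository's code; the warmup_buffers helper, the graphs dict, MAX_BT, and the batch-size list are assumptions made up for the example (the real method is FlashCausalLM.cuda_graph_warmup, shown in the diff).

# Illustrative sketch only -- not text-generation-inference code.
# Idea from the diff: the largest batch size allocates the warmup tensors;
# smaller batch sizes reuse slices (views) of them, so no extra VRAM is used.
import torch

MAX_BT = 64  # assumed block-table width per request (made up for the example)


def warmup_buffers(bs: int, max_s: int, graphs: dict, device: str = "cpu") -> dict:
    max_bs = max(graphs.keys()) if graphs else None
    if max_bs is None:
        # First call (largest batch size): allocate the static tensors once.
        buffers = {
            "input_ids": torch.zeros(bs, dtype=torch.int64, device=device),
            "position_ids": torch.zeros(bs, dtype=torch.int32, device=device),
            "slots": torch.arange(bs, dtype=torch.int64, device=device),
            "input_lengths": torch.full((bs,), max_s, dtype=torch.int32, device=device),
            "block_tables": torch.arange(MAX_BT, dtype=torch.int32, device=device)
            .repeat(bs)
            .reshape(bs, MAX_BT),
        }
    else:
        if bs > max_bs:
            # Mirrors the assertion added in the diff.
            raise RuntimeError("Warm up batch sizes in decreasing order")
        # Later calls: slice the largest graph's tensors; slices are views,
        # so no new device memory is allocated.
        buffers = {name: t[:bs] for name, t in graphs[max_bs].items()}
    graphs[bs] = buffers
    return buffers


if __name__ == "__main__":
    graphs = {}
    for bs in (32, 16, 8, 4, 2, 1):  # decreasing order, as the assertion requires
        warmup_buffers(bs, max_s=128, graphs=graphs)
    # Smaller entries share storage with the bs=32 buffers.
    assert graphs[8]["input_ids"].data_ptr() == graphs[32]["input_ids"].data_ptr()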
@@ -1389,29 +1389,48 @@ class FlashCausalLM(Model):
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
         input_lengths = [max_s] * bs
         cache_lengths = [0] * bs
-        input_lengths_tensor = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        block_tables = torch.arange(
-            max_bt, dtype=torch.int32, device=self.device
-        ).repeat(bs)
-        block_tables = block_tables.reshape((bs, max_bt))
+        if max_bs is None:
+            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+            input_lengths_tensor = (
+                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+            )
+            cache_lengths_tensor = torch.zeros(
+                bs, dtype=torch.int32, device=self.device
+            )
+            block_tables = torch.arange(
+                max_bt, dtype=torch.int32, device=self.device
+            ).repeat(bs)
+            block_tables = block_tables.reshape((bs, max_bt))
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths,
+                    input_lengths_tensor=input_lengths_tensor,
+                    cache_lengths_tensor=cache_lengths_tensor,
+                    max_current_length=max_s,
+                )
+        else:
+            if bs > max_bs:
+                raise RuntimeError(
+                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
+                )
+            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
+            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
+            if ATTENTION == "flashinfer":
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
+            else:
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
+            slots = self.cuda_graphs[max_bs]["slots"][:bs]
+            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
+            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
 
         if ATTENTION == "flashinfer":
-            block_tables = block_tables_to_ragged(
-                block_tables=block_tables,
-                input_lengths=input_lengths,
-                cache_lengths=cache_lengths,
-                input_lengths_tensor=input_lengths_tensor,
-                cache_lengths_tensor=cache_lengths_tensor,
-                max_current_length=max_s,
-            )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
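One detail of the reuse branch worth noting: with ATTENTION == "flashinfer" the block tables appear to be kept in the flattened, ragged layout produced by block_tables_to_ragged, so the reused slice has to cover bs * max_bt entries, whereas the default layout is a (bs, max_bt) matrix and slicing the first bs rows suffices. Either way the slice is a view into the largest graph's tensor, which is where the saving comes from. A tiny illustration with made-up sizes (not repository code):

# Illustrative only (sizes are assumptions): the two block-table reuse paths.
import torch

max_bs, bs, max_bt = 32, 8, 64
ragged = torch.arange(max_bs * max_bt, dtype=torch.int32)  # flattened layout (flashinfer path)
matrix = ragged.reshape(max_bs, max_bt)                    # (batch, blocks) layout (default path)

ragged_view = ragged[: bs * max_bt]  # bs * max_bt entries of the flattened table
matrix_view = matrix[:bs]            # first bs rows of the matrix

# Both are views into the original storage: no new allocation.
assert ragged_view.data_ptr() == ragged.data_ptr()
assert matrix_view.data_ptr() == matrix.data_ptr()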