Saving some VRAM. (#2790)

* Saving some VRAM.

- 8B on 4xL4, attention=flashdecoding. Before: 4.28GB left per GPU, after: 4.32GB
  left, i.e. roughly 40MB saved per shard (~160MB across the 4 GPUs).

- The effect is not as visible with attention=flashinfer or with n_shard=1; I
  suspect this is linked to the torch allocator.

* Adding an assertion (CUDA graphs must be warmed up in decreasing batch-size order).
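
The gist of the change, as a minimal standalone sketch (hypothetical names such as cuda_graph_warmup_sketch and the dict layout are illustrative, not TGI's actual API): buffers are allocated once for the largest CUDA-graph batch size, and every smaller graph reuses [:bs] views into them instead of allocating fresh tensors.

import torch


def cuda_graph_warmup_sketch(batch_sizes, max_s, device="cuda"):
    # Illustrative helper showing the buffer reuse; not TGI's implementation.
    graphs = {}
    for bs in sorted(batch_sizes, reverse=True):
        if not graphs:
            # Largest batch size: allocate the real buffers once.
            input_ids = torch.zeros(bs, dtype=torch.int64, device=device)
            position_ids = torch.zeros(bs, dtype=torch.int32, device=device)
            slots = torch.arange(bs, dtype=torch.int64, device=device)
            input_lengths = torch.full((bs,), max_s, dtype=torch.int32, device=device)
        else:
            # Smaller batch sizes: slice views out of the largest graph's
            # buffers, which share storage and cost no additional VRAM.
            max_bs = max(graphs)
            if bs > max_bs:
                raise RuntimeError("warm up graphs in decreasing batch-size order")
            largest = graphs[max_bs]
            input_ids = largest["input_ids"][:bs]
            position_ids = largest["position_ids"][:bs]
            slots = largest["slots"][:bs]
            input_lengths = largest["input_lengths"][:bs]
        graphs[bs] = {
            "input_ids": input_ids,
            "position_ids": position_ids,
            "slots": slots,
            "input_lengths": input_lengths,
        }
    return graphs

With, say, batch_sizes=[1, 2, 4, 8, 16, 32], only the bs=32 buffers are actually allocated; every other entry is a view into them.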
Nicolas Patry 2024-12-03 08:34:21 +05:30 committed by GitHub
parent 2003d8be0c
commit b57f370386
1 changed file with 38 additions and 19 deletions

@@ -1389,29 +1389,48 @@ class FlashCausalLM(Model):
         ]
 
     def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
-        input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
-        position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+        max_bs = max(self.cuda_graphs.keys()) if self.cuda_graphs else None
         input_lengths = [max_s] * bs
         cache_lengths = [0] * bs
-        input_lengths_tensor = (
-            torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
-        )
-        cache_lengths_tensor = torch.zeros(bs, dtype=torch.int32, device=self.device)
-        block_tables = torch.arange(
-            max_bt, dtype=torch.int32, device=self.device
-        ).repeat(bs)
-        block_tables = block_tables.reshape((bs, max_bt))
+        if max_bs is None:
+            input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
+            position_ids = torch.zeros(bs, dtype=torch.int32, device=self.device)
+            slots = torch.arange(bs, dtype=torch.int64, device=self.device)
+            input_lengths_tensor = (
+                torch.ones(bs, dtype=torch.int32, device=self.device) * max_s
+            )
+            cache_lengths_tensor = torch.zeros(
+                bs, dtype=torch.int32, device=self.device
+            )
+            block_tables = torch.arange(
+                max_bt, dtype=torch.int32, device=self.device
+            ).repeat(bs)
+            block_tables = block_tables.reshape((bs, max_bt))
+            if ATTENTION == "flashinfer":
+                block_tables = block_tables_to_ragged(
+                    block_tables=block_tables,
+                    input_lengths=input_lengths,
+                    cache_lengths=cache_lengths,
+                    input_lengths_tensor=input_lengths_tensor,
+                    cache_lengths_tensor=cache_lengths_tensor,
+                    max_current_length=max_s,
+                )
+        else:
+            if bs > max_bs:
+                raise RuntimeError(
+                    "Cuda graphs should be generated in decreasing order size to reduce VRAM usage"
+                )
+            input_ids = self.cuda_graphs[max_bs]["input_ids"][:bs]
+            position_ids = self.cuda_graphs[max_bs]["position_ids"][:bs]
+            if ATTENTION == "flashinfer":
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][: bs * max_bt]
+            else:
+                block_tables = self.cuda_graphs[max_bs]["block_tables"][:bs]
+            slots = self.cuda_graphs[max_bs]["slots"][:bs]
+            input_lengths_tensor = self.cuda_graphs[max_bs]["input_lengths"][:bs]
+            cache_lengths_tensor = self.cuda_graphs[max_bs]["cache_lengths"][:bs]
 
         if ATTENTION == "flashinfer":
-            block_tables = block_tables_to_ragged(
-                block_tables=block_tables,
-                input_lengths=input_lengths,
-                cache_lengths=cache_lengths,
-                input_lengths_tensor=input_lengths_tensor,
-                cache_lengths_tensor=cache_lengths_tensor,
-                max_current_length=max_s,
-            )
             from text_generation_server.layers.attention.flashinfer import (
                 create_decode_state_cuda_graphs,
             )
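
For reference, the reason the [:bs] indexing in the new code path is essentially free: basic slicing in PyTorch returns a view that shares the original tensor's storage rather than copying it. A quick illustrative check (assumed toy sizes):

import torch

big = torch.zeros(64, dtype=torch.int64)  # buffer sized for the largest graph
small = big[:8]                           # view used by a smaller batch size

# Both tensors point at the same underlying memory: no extra allocation.
assert small.data_ptr() == big.data_ptr()

This is also why the new RuntimeError requires decreasing-order warmup: a view can only be carved out of a buffer that is at least as large, so the largest batch size has to be captured first.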