Temporary implementation of torch.compile for our models.

Nicolas Patry 2024-03-21 18:56:40 +00:00
parent 6f15ac60b2
commit 78f87d5a0c
5 changed files with 75 additions and 34 deletions


@@ -798,11 +798,13 @@ class FlashCausalLM(Model):
             self.device,
         )
+        self.compiled_model = torch.compile(self.model, mode="reduce-overhead")
         if ENABLE_CUDA_GRAPHS:
             try:
                 logger.info("Experimental support for Cuda Graphs is enabled")
                 # Warmup cuda graphs
-                for bs in [1, 2, 4] + [8 * i for i in range(8)]:
+                for bs in [1]:
                     if self.speculate is None or self.speculate + 1 <= bs:
                         self.cuda_graph_warmup(bs, max_s, max_bt)
             except Exception:
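For context, mode="reduce-overhead" asks torch.compile to capture CUDA graphs around the compiled region, which mostly pays off for the small, fixed shapes of decode. Below is a minimal standalone sketch of compiling once and warming the compiled path before serving traffic; the toy module and batch sizes are illustrative, not the repository's code.

import torch

# Toy stand-in for the real flash model; requires a CUDA device.
model = torch.nn.Sequential(
    torch.nn.Embedding(32000, 256),
    torch.nn.Linear(256, 32000),
).cuda()

# Compile once at warmup time; "reduce-overhead" enables CUDA graph capture.
compiled_model = torch.compile(model, mode="reduce-overhead")

# Run the decode batch sizes we expect so compilation and graph capture
# happen during warmup instead of on the first user request.
with torch.inference_mode():
    for bs in [1]:
        dummy_ids = torch.zeros(bs, dtype=torch.long, device="cuda")
        for _ in range(3):  # a few calls so CUDA graph capture actually kicks in
            compiled_model(dummy_ids)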
@@ -881,7 +883,19 @@ class FlashCausalLM(Model):
             or cuda_graph is None
             or batch.speculative_ids is not None
         ):
-            return self.model.forward(
+            if cu_seqlen_prefill is None:
+                return self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    lm_head_indices=lm_head_indices,
+                )
+            return self.model(
                 input_ids=input_ids,
                 position_ids=position_ids,
                 cu_seqlen_prefill=cu_seqlen_prefill,
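The dispatch above routes only decode calls (cu_seqlen_prefill is None) through the compiled module and keeps prefill on the eager model, since prefill shapes vary with prompt length while decode shapes stay small and regular. A minimal sketch of that routing with a hypothetical wrapper; the class name and kwargs-only signature are illustrative.

import torch
from typing import Optional

class CompiledDecodeWrapper:
    # Sketch: compile the model once, but use the compiled path only for decode.
    def __init__(self, model: torch.nn.Module):
        self.model = model
        self.compiled_model = torch.compile(model, mode="reduce-overhead")

    def forward(self, cu_seqlen_prefill: Optional[torch.Tensor], **kwargs):
        if cu_seqlen_prefill is None:
            # Decode step: small, regular shapes -> compiled / CUDA-graph path.
            return self.compiled_model(**kwargs)
        # Prefill: variable sequence lengths -> plain eager forward.
        return self.model(**kwargs)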


@@ -68,6 +68,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         model = FlashLlamaForCausalLM(config, weights)
+        # model = torch.compile(model, mode="reduce-overhead")
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model=model,


@@ -495,18 +495,33 @@ class BaseFlashMistral(FlashCausalLM):
         cuda_graph = self.cuda_graphs.get(padded_bs, None)
         if cu_seqlen_prefill is not None or cuda_graph is None:
-            logits, speculative_logits = self.model.forward(
-                input_ids=input_ids,
-                position_ids=position_ids,
-                cu_seqlen_prefill=cu_seqlen_prefill,
-                kv_cache=kv_cache,
-                block_tables=block_tables,
-                slots=slots,
-                input_lengths=input_lengths,
-                max_s=max_s,
-                prefill_cache_indices=batch.prefill_cache_indices,
-                lm_head_indices=lm_head_indices,
-            )
+            if cu_seqlen_prefill is None:
+                logits, speculative_logits = self.compiled_model(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
+            else:
+                logits, speculative_logits = self.model.forward(
+                    input_ids=input_ids,
+                    position_ids=position_ids,
+                    cu_seqlen_prefill=cu_seqlen_prefill,
+                    kv_cache=kv_cache,
+                    block_tables=block_tables,
+                    slots=slots,
+                    input_lengths=input_lengths,
+                    max_s=max_s,
+                    prefill_cache_indices=batch.prefill_cache_indices,
+                    lm_head_indices=lm_head_indices,
+                )
             if batch.prefill_cache_indices is not None:
                 batch.prefill_cache_indices = None
             return logits, speculative_logits
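One practical caveat with compiling the decode path is that every new batch size can trigger a fresh compile (and, under reduce-overhead, a fresh CUDA graph). Two standard mitigations are sketched below: padding the batch to an already warmed-up bucket (the cuda_graphs dict above is already keyed by padded_bs) and marking the batch dimension dynamic. This is general torch.compile usage, not something this commit adds.

import torch
import torch._dynamo as dynamo

def pad_to_bucket(bs: int, buckets=(1, 2, 4, 8, 16, 32)) -> int:
    # Reuse an already-compiled/captured size instead of recompiling per batch size.
    for b in buckets:
        if bs <= b:
            return b
    return bs

def mark_batch_dynamic(input_ids: torch.Tensor) -> None:
    # Ask Dynamo to treat dim 0 as symbolic so one graph covers all batch sizes
    # (this forgoes the fixed shapes that CUDA graph capture relies on).
    dynamo.mark_dynamic(input_ids, 0)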


@@ -149,7 +149,19 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
             batch = batches[0]
             concat_ns = None
-        generations, next_batch, timings = self.model.generate_token(batch)
+        torch.profiler._utils._init_for_cuda_graphs()
+        # prof = torch.profiler.profile()
+        # if self.model.rank != 0:
+        if True:
+            import contextlib
+            prof = contextlib.nullcontext()
+        else:
+            prof = torch.profiler.profile()
+        with prof:
+            generations, next_batch, timings = self.model.generate_token(batch)
+        # if self.model.rank == 0:
+        #     prof.export_chrome_trace(f"out_rank_0.json")
         self.cache.set(next_batch)
         return generate_pb2.DecodeResponse(
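The block above is profiling scaffolding: as committed it always picks a no-op context, with the real profiler and the rank check left commented out. A tidier sketch of the same idea, gating the profiler behind a flag and exporting a Chrome trace afterwards; the enabled flag and rank handling are illustrative, not the server's actual wiring.

import contextlib
import torch

def maybe_profile(enabled: bool):
    # Real profiler only when requested; otherwise a no-op context manager,
    # so the decode hot path is untouched.
    if enabled:
        return torch.profiler.profile()
    return contextlib.nullcontext()

rank = 0  # illustrative; the server would read this from its process group
with maybe_profile(enabled=(rank == 0)) as prof:
    pass  # generate_token(batch) would run here
if isinstance(prof, torch.profiler.profile):
    prof.export_chrome_trace("out_rank_0.json")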


@@ -507,27 +507,27 @@ class TensorParallelHead(SuperLayer):
             return super().forward(input)

         world_size = self.process_group.size()
-        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
-            out_dim = self.linear.weight.shape[0]
+        # if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+        #     out_dim = self.linear.weight.shape[0]

-            if input.shape[0] == 1:
-                world_out = input.new_empty(1, out_dim * world_size)
-                local_out = input.new_empty(1, out_dim)
-                gather_input = local_out
-            else:
-                world_out = input.new_empty(out_dim * world_size, input.shape[0])
-                gather_input = input.new_empty(out_dim, input.shape[0])
-                local_out = gather_input.T
+        #     if input.shape[0] == 1:
+        #         world_out = input.new_empty(1, out_dim * world_size)
+        #         local_out = input.new_empty(1, out_dim)
+        #         gather_input = local_out
+        #     else:
+        #         world_out = input.new_empty(out_dim * world_size, input.shape[0])
+        #         gather_input = input.new_empty(out_dim, input.shape[0])
+        #         local_out = gather_input.T

-            torch.mm(input, self.linear.weight.T, out=local_out)
+        #     torch.mm(input, self.linear.weight.T, out=local_out)

-            torch.distributed.all_gather_into_tensor(
-                world_out, gather_input, group=self.process_group
-            )
+        #     torch.distributed.all_gather_into_tensor(
+        #         world_out, gather_input, group=self.process_group
+        #     )

-            if input.shape[0] == 1:
-                return world_out
-            return world_out.T
+        #     if input.shape[0] == 1:
+        #         return world_out
+        #     return world_out.T

         output = super().forward(input)
         world_output = [
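For reference, the branch disabled above computes each rank's slice of the output with a matmul into a preallocated buffer and then assembles the full result with a single all_gather_into_tensor; the diff itself does not state why it was commented out. A standalone sketch of the single-row (decode) case follows; it assumes an already initialized process group (e.g. launched with torchrun), and the function and argument names are illustrative.

import torch
import torch.distributed as dist

def sharded_head_matmul(hidden: torch.Tensor, weight_shard: torch.Tensor, group) -> torch.Tensor:
    # hidden: [1, hidden_dim]; weight_shard: [out_shard, hidden_dim] on each rank.
    world_size = dist.get_world_size(group)
    out_dim = weight_shard.shape[0]

    # Local slice of the output, written into a preallocated buffer.
    local_out = hidden.new_empty(1, out_dim)
    torch.mm(hidden, weight_shard.T, out=local_out)

    # Gather every rank's slice into one [1, out_dim * world_size] row.
    world_out = hidden.new_empty(1, out_dim * world_size)
    dist.all_gather_into_tensor(world_out, local_out, group=group)
    return world_out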
@@ -786,6 +786,7 @@ try:
             self._sin_k_cached = None
             self.scaling_factor = scaling_factor
             self.dynamic_args = None
+            self._update_cos_sin_cache(torch.float16, inv_freq.device, seqlen=4096)

         def forward(
             self,
@@ -929,8 +930,6 @@ try:
             # But later on goes and cast cos/sin to float anyway: https://github.com/Dao-AILab/flash-attention/blob/017716451d446e464dde9aca3a3c1ed2209caaa9/csrc/rotary/rotary_cuda.cu#L29, which looks suboptimal.
             dtype = torch.float32

             self._update_cos_sin_cache(dtype, position_ids.device, max_s)
-            cos = torch.index_select(self._cos_cached, 0, position_ids)
-            sin = torch.index_select(self._sin_cached, 0, position_ids)

             # Note: this unsqueeze is not necessary on RoCm + VLLM ROPE implementation, but we leave it as is to avoid yet an other controlflow.
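The two rotary-embedding changes above fill the cos/sin cache once at construction time (up to seqlen=4096) and drop the per-call index_select from this method; the diff does not state the motivation, but a cache built entirely at init never has to grow or branch inside a compiled decode call. A minimal sketch of such a precomputed RoPE cache; the class name, dtype and maximum length are illustrative, not the repository's implementation.

import torch

class RotaryCache(torch.nn.Module):
    # Sketch: precompute the rotary cos/sin tables once, then only index them.
    def __init__(self, dim: int, max_seqlen: int = 4096, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seqlen, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # [max_seqlen, dim // 2]
        self.register_buffer("cos_cached", freqs.cos().to(torch.float16), persistent=False)
        self.register_buffer("sin_cached", freqs.sin().to(torch.float16), persistent=False)

    def get_cos_sin(self, position_ids: torch.Tensor):
        # Plain indexing into the prebuilt tables; no cache growth at call time.
        return self.cos_cached[position_ids], self.sin_cached[position_ids]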