feat(server): optimize dist ops (#434)
This commit is contained in:
parent
abd58ff82c
commit
e496c9ba5b
|
@ -265,7 +265,8 @@ class FlashNeoXLayer(nn.Module):
|
||||||
mlp_output = self.mlp(ln2_hidden_states)
|
mlp_output = self.mlp(ln2_hidden_states)
|
||||||
intermediate = mlp_output + attn_output
|
intermediate = mlp_output + attn_output
|
||||||
|
|
||||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||||
|
|
||||||
return intermediate + hidden_states, None
|
return intermediate + hidden_states, None
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -440,7 +440,8 @@ class FlashRWLayer(nn.Module):
|
||||||
mlp_output = self.mlp(ln_hidden_states)
|
mlp_output = self.mlp(ln_hidden_states)
|
||||||
intermediate = mlp_output + attn_output
|
intermediate = mlp_output + attn_output
|
||||||
|
|
||||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||||
|
|
||||||
return intermediate, residual
|
return intermediate, residual
|
||||||
else:
|
else:
|
||||||
|
@ -524,7 +525,8 @@ class FlashRWLargeLayer(nn.Module):
|
||||||
|
|
||||||
intermediate = attn_output + mlp_output
|
intermediate = attn_output + mlp_output
|
||||||
|
|
||||||
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(intermediate, group=self.process_group)
|
||||||
|
|
||||||
return intermediate, residual
|
return intermediate, residual
|
||||||
|
|
||||||
|
|
|
@ -346,7 +346,9 @@ class FlashSantacoderModel(nn.Module):
|
||||||
pre_allocate_past_size: Optional[int] = None,
|
pre_allocate_past_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
hidden_states = self.wte(input_ids) + self.wpe(position_ids)
|
||||||
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
|
||||||
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(hidden_states, group=self.process_group)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if past_key_values is None:
|
if past_key_values is None:
|
||||||
|
|
|
@ -158,8 +158,33 @@ class TensorParallelHead(SuperLayer):
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
|
world_size = self.process_group.size()
|
||||||
|
if world_size == 1:
|
||||||
|
return super().forward(input)
|
||||||
|
|
||||||
|
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
|
||||||
|
out_dim = self.linear.weight.shape[0]
|
||||||
|
|
||||||
|
if input.shape[0] == 1:
|
||||||
|
world_out = input.new_empty(1, out_dim * world_size)
|
||||||
|
local_out = input.new_empty(1, out_dim)
|
||||||
|
gather_input = local_out
|
||||||
|
else:
|
||||||
|
world_out = input.new_empty(out_dim * world_size, input.shape[0])
|
||||||
|
gather_input = input.new_empty(out_dim, input.shape[0])
|
||||||
|
local_out = gather_input.T
|
||||||
|
|
||||||
|
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||||
|
|
||||||
|
torch.distributed.all_gather_into_tensor(
|
||||||
|
world_out, gather_input, group=self.process_group
|
||||||
|
)
|
||||||
|
|
||||||
|
if input.shape[0] == 1:
|
||||||
|
return world_out
|
||||||
|
return world_out.T
|
||||||
|
|
||||||
output = super().forward(input)
|
output = super().forward(input)
|
||||||
# Logits are sharded, so we need to gather them
|
|
||||||
world_output = [
|
world_output = [
|
||||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||||
]
|
]
|
||||||
|
@ -211,7 +236,8 @@ class TensorParallelRowLinear(SuperLayer):
|
||||||
|
|
||||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||||
out = super().forward(input)
|
out = super().forward(input)
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
if self.process_group.size() > 1:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -245,7 +271,7 @@ class TensorParallelEmbedding(nn.Module):
|
||||||
input - self.min_id,
|
input - self.min_id,
|
||||||
)
|
)
|
||||||
out = torch.nn.functional.embedding(input, self.weight)
|
out = torch.nn.functional.embedding(input, self.weight)
|
||||||
if self.reduce:
|
if self.reduce and self.process_group.size() > 1:
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue