diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index 0fe43bcb..c045f16e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -265,7 +265,8 @@ class FlashNeoXLayer(nn.Module):
             mlp_output = self.mlp(ln2_hidden_states)
             intermediate = mlp_output + attn_output
 
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+            if self.process_group.size() > 1:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate + hidden_states, None
         else:
diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
index 55195162..af9fa548 100644
--- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py
@@ -440,7 +440,8 @@ class FlashRWLayer(nn.Module):
             mlp_output = self.mlp(ln_hidden_states)
             intermediate = mlp_output + attn_output
 
-            torch.distributed.all_reduce(intermediate, group=self.process_group)
+            if self.process_group.size() > 1:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
 
             return intermediate, residual
         else:
@@ -524,7 +525,8 @@ class FlashRWLargeLayer(nn.Module):
 
         intermediate = attn_output + mlp_output
 
-        torch.distributed.all_reduce(intermediate, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(intermediate, group=self.process_group)
 
         return intermediate, residual
 
diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
index 888a6066..fcf6be68 100644
--- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -346,7 +346,9 @@ class FlashSantacoderModel(nn.Module):
         pre_allocate_past_size: Optional[int] = None,
     ):
         hidden_states = self.wte(input_ids) + self.wpe(position_ids)
-        torch.distributed.all_reduce(hidden_states, group=self.process_group)
+
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(hidden_states, group=self.process_group)
 
         # Prefill
         if past_key_values is None:
diff --git a/server/text_generation_server/utils/layers.py b/server/text_generation_server/utils/layers.py
index ee32a0dc..93865d52 100644
--- a/server/text_generation_server/utils/layers.py
+++ b/server/text_generation_server/utils/layers.py
@@ -158,8 +158,33 @@ class TensorParallelHead(SuperLayer):
         )
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        world_size = self.process_group.size()
+        if world_size == 1:
+            return super().forward(input)
+
+        if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
+            out_dim = self.linear.weight.shape[0]
+
+            if input.shape[0] == 1:
+                world_out = input.new_empty(1, out_dim * world_size)
+                local_out = input.new_empty(1, out_dim)
+                gather_input = local_out
+            else:
+                world_out = input.new_empty(out_dim * world_size, input.shape[0])
+                gather_input = input.new_empty(out_dim, input.shape[0])
+                local_out = gather_input.T
+
+            torch.mm(input, self.linear.weight.T, out=local_out)
+
+            torch.distributed.all_gather_into_tensor(
+                world_out, gather_input, group=self.process_group
+            )
+
+            if input.shape[0] == 1:
+                return world_out
+            return world_out.T
+
         output = super().forward(input)
-        # Logits are sharded, so we need to gather them
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
@@ -211,7 +236,8 @@ class TensorParallelRowLinear(SuperLayer):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super().forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.process_group.size() > 1:
+            torch.distributed.all_reduce(out, group=self.process_group)
         return out
 
 
@@ -245,7 +271,7 @@ class TensorParallelEmbedding(nn.Module):
             input - self.min_id,
         )
         out = torch.nn.functional.embedding(input, self.weight)
-        if self.reduce:
+        if self.reduce and self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
         return out
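
The `TensorParallelHead.forward` fast path above replaces the list-based `all_gather` with a single fused `torch.distributed.all_gather_into_tensor`. Because that collective concatenates each rank's shard along dim 0, the batch > 1 branch writes the local matmul result through a transposed view (`local_out = gather_input.T`) so the gathered buffer comes out as `[out_dim * world_size, batch]`, and one final `.T` restores `[batch, out_dim * world_size]`. Below is a minimal single-process sketch of that layout argument; the toy sizes and the loop standing in for `all_gather_into_tensor` are assumptions for illustration, not part of the patch.

```python
import torch

world_size, batch, out_dim, hidden = 2, 3, 4, 8  # hypothetical toy sizes

# Full lm_head weight, sharded along the output (vocab) dimension the way
# TensorParallelHead shards it.
full_weight = torch.randn(out_dim * world_size, hidden)
shards = full_weight.chunk(world_size, dim=0)
x = torch.randn(batch, hidden)

# Emulate the batch > 1 branch: each rank computes its logits shard directly
# into a transposed view of its gather buffer, so that concatenation along
# dim 0 (what all_gather_into_tensor does) yields [out_dim * world_size, batch].
world_out = x.new_empty(out_dim * world_size, batch)
for rank, w in enumerate(shards):
    gather_input = x.new_empty(out_dim, batch)
    local_out = gather_input.T  # [batch, out_dim] view into gather_input
    torch.mm(x, w.T, out=local_out)  # matmul writes through the view
    # Stand-in for all_gather_into_tensor: shard k lands in rows
    # [k * out_dim, (k + 1) * out_dim) of the gathered tensor.
    world_out[rank * out_dim : (rank + 1) * out_dim] = gather_input

# Transposing back recovers the unsharded logits, matching the full matmul.
assert torch.allclose(world_out.T, x @ full_weight.T, atol=1e-5)
```

For `input.shape[0] == 1` no transpose is needed at all: flat concatenation of each rank's `[1, out_dim]` shard is already the final `[1, out_dim * world_size]` logits row, which is why that branch gathers straight into `world_out` and returns it as-is.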