feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)

This commit is contained in:
OlivierDehaene 2023-02-01 14:43:59 +01:00 committed by GitHub
parent 404ed7a1f6
commit 2ad895a6cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 25 additions and 17 deletions

View File

@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
- [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
- ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
- [SantaCoder](https://huggingface.co/bigcode/santacoder)
- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
Other models are supported on a best effort basis using:

View File

@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif name == "embed_out.weight":
elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
)
if (
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
type(module)
in [TensorParallelRowLinear, TensorParallelColumnLinear]
and param_name == "weight"
):
tensor = Int8Params(
tensor,
@ -229,16 +229,24 @@ class GPTNeoxSharded(GPTNeox):
def forward(
self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
):
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
if self.model.gpt_neox.tp_embeddings:
outputs = self.model.forward(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=True,
)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
logits = torch.cat(logits, dim=2)
# Logits are sharded, so we need to gather them
logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
torch.distributed.all_gather(
logits, outputs.logits, group=self.process_group
)
logits = torch.cat(logits, dim=2)
return logits, outputs.past_key_values
return logits, outputs.past_key_values
# While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
else:
return super(GPTNeoxSharded, self).forward(
input_ids, attention_mask, position_ids, past_key_values
)

View File

@ -91,7 +91,7 @@ class NextTokenChooser:
top_p=pb.top_p,
do_sample=pb.do_sample,
seed=pb.seed,
device=str(device),
device=device,
)