feat(server): allow gpt-neox models with odd vocab sizes to be sharded (#48)
parent 404ed7a1f6
commit 2ad895a6cc
```diff
@@ -26,7 +26,7 @@ to power Bloom, BloomZ and MT0-XXL api-inference widgets.
 - [MT0-XXL](https://huggingface.co/bigscience/mt0-xxl)
 - ~~[Galactica](https://huggingface.co/facebook/galactica-120b)~~ (deactivated)
 - [SantaCoder](https://huggingface.co/bigcode/santacoder)
-- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision refs/pr/13`
+- [GPT-Neox 20B](https://huggingface.co/EleutherAI/gpt-neox-20b): use `--revision pr/13`
 
 Other models are supported on a best effort basis using:
```
```diff
@@ -145,7 +145,7 @@ class GPTNeoxSharded(GPTNeox):
                     start = rank * block_size
                     stop = (rank + 1) * block_size
                     tensor = slice_[start:stop]
-                elif name == "embed_out.weight":
+                elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
                     start = rank * block_size
```
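The guard ties output-embedding sharding to a new `tp_embeddings` flag. A minimal sketch of the decision and the slice arithmetic, assuming (per the commit title) that `tp_embeddings` is set from a vocab-size divisibility check; only the `block_size`/`start`/`stop` math mirrors the hunk above, the function and argument names are illustrative:

```python
def embedding_shard(vocab_size: int, rank: int, world_size: int):
    """Return (tp_embeddings, (start, stop)) for one rank's slice of embed_out.weight.

    Sharding splits the vocab dimension into equal blocks, so it is only possible
    when vocab_size divides evenly by world_size; otherwise every rank keeps the
    full matrix and the non-sharded code path is used instead.
    """
    tp_embeddings = vocab_size % world_size == 0
    if not tp_embeddings:
        return False, (0, vocab_size)
    block_size = vocab_size // world_size  # as in the hunk above
    start = rank * block_size
    stop = (rank + 1) * block_size
    return True, (start, stop)


# GPT-NeoX-20B pads its vocab to 50432, which splits cleanly across 2 shards;
# an "odd" vocab size such as 50257 would trigger the fallback.
print(embedding_shard(50432, rank=0, world_size=2))  # (True, (0, 25216))
print(embedding_shard(50257, rank=0, world_size=2))  # (False, (0, 50257))
```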
```diff
@@ -176,9 +176,9 @@ class GPTNeoxSharded(GPTNeox):
                     )
 
                     if (
-                        type(module)
-                        in [TensorParallelRowLinear, TensorParallelColumnLinear]
-                        and param_name == "weight"
+                        type(module)
+                        in [TensorParallelRowLinear, TensorParallelColumnLinear]
+                        and param_name == "weight"
                     ):
                         tensor = Int8Params(
                             tensor,
```
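For reference, a sketch of the int8 guard shown in context here: only the `weight` parameters of the tensor-parallel linear layers get wrapped in bitsandbytes' `Int8Params`. The helper name, the `quantizable` tuple argument, and the exact keyword arguments passed to `Int8Params` are assumptions, not the repository's code:

```python
from typing import Tuple, Type

import torch
from bitsandbytes.nn import Int8Params  # requires bitsandbytes


def maybe_quantize(
    module: torch.nn.Module,
    param_name: str,
    tensor: torch.Tensor,
    quantizable: Tuple[Type[torch.nn.Module], ...],
) -> torch.Tensor:
    """Wrap only linear-layer weights in Int8Params, leaving everything else as-is.

    In the diff above, `quantizable` corresponds to
    (TensorParallelRowLinear, TensorParallelColumnLinear); biases, layer norms
    and embeddings keep their original dtype.
    """
    if isinstance(module, quantizable) and param_name == "weight":
        return Int8Params(tensor, requires_grad=False, has_fp16_weights=False)
    return tensor
```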
```diff
@@ -229,16 +229,24 @@ class GPTNeoxSharded(GPTNeox):
     def forward(
         self, input_ids, attention_mask, position_ids, past_key_values: Optional = None
     ):
-        outputs = self.model.forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            past_key_values=past_key_values,
-            use_cache=True,
-        )
-
-        # Logits are sharded, so we need to gather them
-        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
-        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
-        logits = torch.cat(logits, dim=2)
-
-        return logits, outputs.past_key_values
+        if self.model.gpt_neox.tp_embeddings:
+            outputs = self.model.forward(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                past_key_values=past_key_values,
+                use_cache=True,
+            )
+
+            # Logits are sharded, so we need to gather them
+            logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
+            torch.distributed.all_gather(
+                logits, outputs.logits, group=self.process_group
+            )
+            logits = torch.cat(logits, dim=2)
+
+            return logits, outputs.past_key_values
+        # While the model itself is sharded, the embeddings might not as they might not be dividable by num-shard
+        else:
+            return super(GPTNeoxSharded, self).forward(
+                input_ids, attention_mask, position_ids, past_key_values
+            )
```
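When `tp_embeddings` is true, each rank only produces a shard of the vocab dimension of the logits, so the shards have to be all-gathered and concatenated along that axis before sampling; when it is false, the non-sharded forward of the parent class already returns full-size logits. A self-contained sketch of the gather step, mirroring the branch above (the function name and the explicit `world_size`/`process_group` arguments are illustrative):

```python
import torch
import torch.distributed as dist


def gather_sharded_logits(
    local_logits: torch.Tensor, world_size: int, process_group=None
) -> torch.Tensor:
    """Reassemble full-vocabulary logits from per-rank shards.

    Each rank holds logits of shape [batch, seq_len, vocab_size // world_size];
    all_gather collects every rank's shard and cat(dim=2) restores the full
    vocab dimension.
    """
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits, group=process_group)
    return torch.cat(shards, dim=2)
```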
```diff
@@ -91,7 +91,7 @@ class NextTokenChooser:
             top_p=pb.top_p,
             do_sample=pb.do_sample,
             seed=pb.seed,
-            device=str(device),
+            device=device,
         )
 
```
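The last hunk stops stringifying the device before handing it to the token chooser. A minimal sketch of why the device matters there, assuming the seeded-sampling pattern below (`Sampling` and its internals are illustrative, not the repository's implementation): a per-request `torch.Generator` must live on the same device as the logits it samples from, and passing the `torch.device` straight through keeps that explicit.

```python
import torch


class Sampling:
    """Seeded multinomial sampling bound to a specific device (illustrative)."""

    def __init__(self, seed: int, device: torch.device):
        # The generator must be created on the device of the tensors it samples from.
        self.generator = torch.Generator(device=device)
        self.generator.manual_seed(seed)

    def __call__(self, probs: torch.Tensor) -> torch.Tensor:
        return torch.multinomial(probs, num_samples=1, generator=self.generator)


# Example on CPU; on a sharded server this would be the rank's CUDA device.
sampler = Sampling(seed=0, device=torch.device("cpu"))
probs = torch.softmax(torch.randn(1, 8), dim=-1)
next_token = sampler(probs)
```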