import torch
import torch.distributed

from typing import List, Optional, Tuple

from accelerate import init_empty_weights
from safetensors import safe_open
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    AutoConfig,
)
from transformers.models.t5.parallel_layers import (
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)

from text_generation.models import Seq2SeqLM
from text_generation.utils import (
    initialize_torch_distributed,
    weight_files,
    download_weights,
)

HAS_BITS_AND_BYTES = True
try:
    import bitsandbytes as bnb
    from bitsandbytes.nn import Int8Params
except Exception as e:
    HAS_BITS_AND_BYTES = False

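
# T5 sharded across ranks with tensor parallelism: each rank keeps only its slice
# of the tensor-parallel layers, loaded directly from safetensors files in
# `load_weights` below, so no single rank materializes the full model.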
class T5Sharded(Seq2SeqLM):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.bfloat16
        else:
            device = torch.device("cpu")
            dtype = torch.float32

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id, revision=revision, tp_parallel=True
        )
        tokenizer.bos_token_id = config.decoder_start_token_id

        # Only the master process downloads the weights
        if self.master:
            download_weights(model_id, revision=revision, extension=".safetensors")

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")
        if not filenames:
            raise ValueError("No safetensors weights found")

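        # Build the model skeleton on the meta device (no memory is allocated);
        # the actual sharded weights are attached in `load_weights` below.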
        with init_empty_weights():
            model = AutoModelForSeq2SeqLM.from_config(config)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(dtype)
        torch.distributed.barrier(group=self.process_group)
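        # Skip Seq2SeqLM.__init__ and call its parent initializer directly:
        # the sharded model has already been built and loaded above.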
        super(Seq2SeqLM, self).__init__(
            tokenizer=tokenizer,
            device=device,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        rank: int,
        world_size: int,
    ):
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

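                    # Shard each tensor according to the module that owns it:
                    # column-parallel linears, embeddings and the lm_head are
                    # split along dim 0, row-parallel linear weights and the
                    # relative attention bias along dim 1; everything else is
                    # loaded whole on every rank.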
                    if isinstance(module, TensorParallelColumnLinear):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            size = slice_.get_shape()[1]
                            block_size = size // world_size
                            start = rank * block_size
                            stop = (rank + 1) * block_size
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for RowLinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif name == "lm_head.weight":
                        size = slice_.get_shape()[0]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[start:stop]
                    elif "relative_attention_bias.weight" in name:
                        size = slice_.get_shape()[1]
                        block_size = size // world_size
                        start = rank * block_size
                        stop = (rank + 1) * block_size
                        tensor = slice_[:, start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous()

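                    # When quantize is set, tensor-parallel linear weights are converted
                    # to bitsandbytes Int8Params and their matmul is routed through
                    # bnb.matmul with a shared LLM.int8() state (outlier threshold 6.0).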
                    if quantize:
                        if not HAS_BITS_AND_BYTES:
                            raise ImportError(
                                "bitsandbytes is not available on your machine either because it is not installed "
                                "or you don't have a GPU.\n"
                                "You can install it with `pip install bitsandbytes`."
                            )

                        if (
                            type(module)
                            in [TensorParallelRowLinear, TensorParallelColumnLinear]
                            and param_name == "weight"
                        ):
                            tensor = Int8Params(
                                tensor,
                                has_fp16_weights=False,
                                requires_grad=False,
                            ).to(device)
                            state = bnb.MatmulLtState()
                            state.threshold = 6.0
                            state.has_fp16_weights = False
                            state.memory_efficient_backward = False
                            state.use_pool = True
                            state.CB = tensor.CB
                            state.SCB = tensor.SCB
                            tensor.CB = None
                            tensor.SCB = None

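                            # `replace_linear` closes over the quantization state and
                            # substitutes an int8 matmul for the layer's linear op.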
                            def replace_linear(state):
                                def linear(input, weight, bias):
                                    out = bnb.matmul(
                                        input,
                                        weight,
                                        state=state,
                                        threshold=state.threshold,
                                        bias=bias,
                                    )

                                    if state.CB is not None:
                                        # We converted the 8-bit row-major weight to the
                                        # Turing/Ampere format in the first inference pass;
                                        # the row-major copy is no longer needed.
                                        del state.CB
                                        weight.data = state.CxB

                                    return out

                                return linear

                            module.linear = replace_linear(state)

                        else:
                            tensor = tensor.to(device)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

    def forward(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask: Optional[torch.Tensor],
        encoder_last_hidden_state: Optional[torch.Tensor],
        past_key_values: Optional[List[Tuple]] = None,
    ) -> Tuple[
        torch.Tensor,
        torch.Tensor,
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
    ]:
        # Model Forward
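        # With a KV cache, only the most recent decoder token needs to be passed.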
        if past_key_values is not None:
            decoder_input_ids = decoder_input_ids[:, -1].unsqueeze(-1)

        # Wrap `encoder_last_hidden_state` because, for some reason, Transformers
        # does an `encoder_last_hidden_state[0]` internally...
        if encoder_last_hidden_state is not None:
            encoder_last_hidden_state = [encoder_last_hidden_state]

        outputs = self.model.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            encoder_outputs=encoder_last_hidden_state,
            past_key_values=past_key_values,
            use_cache=True,
        )

        # Logits are sharded, so we need to gather them
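        # (the lm_head is split along the vocab dimension, so each rank holds only a
        # slice of the logits; concatenating on the last dim restores the full vocab)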
        logits = [torch.empty_like(outputs.logits) for _ in range(self.world_size)]
        torch.distributed.all_gather(logits, outputs.logits, group=self.process_group)
        logits = torch.cat(logits, dim=2)

        return (
            logits,
            outputs.encoder_last_hidden_state,
            outputs.past_key_values,
        )
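
# Example usage (sketch): every rank builds the same T5Sharded instance; the model id
# below is only illustrative, and the launcher (e.g. `torchrun`) must provide the usual
# torch.distributed environment variables expected by `initialize_torch_distributed`.
#
#   model = T5Sharded("t5-11b", quantize=False)
#   logits, encoder_last_hidden_state, past = model.forward(...)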