2023-03-24 07:02:14 -06:00
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
from opentelemetry import trace
|
|
|
|
from safetensors import safe_open
|
2023-04-03 11:06:42 -06:00
|
|
|
from transformers import AutoTokenizer, AutoConfig
|
2023-04-12 09:18:08 -06:00
|
|
|
from typing import Optional, List
|
2023-03-24 07:02:14 -06:00
|
|
|
|
2023-04-03 11:06:42 -06:00
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
|
|
from text_generation_server.models.custom_modeling.flash_neox_modeling import (
|
2023-03-24 07:02:14 -06:00
|
|
|
FlashGPTNeoXForCausalLM,
|
|
|
|
TensorParallelEmbedding,
|
|
|
|
TensorParallelRowLinear,
|
|
|
|
TensorParallelColumnLinear,
|
|
|
|
)
|
|
|
|
from text_generation_server.utils import (
|
|
|
|
initialize_torch_distributed,
|
|
|
|
weight_files,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Module-level OpenTelemetry tracer, identified by this module's __name__.
# NOTE(review): not referenced in the visible part of this file; presumably
# used by span decorators elsewhere or kept for parity with sibling models.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
|
|
|
2023-04-03 11:06:42 -06:00
|
|
|
class FlashNeoX(FlashCausalLM):
    """Single-process flash-attention GPT-NeoX model.

    This class holds no logic of its own. It forwards to ``FlashCausalLM``,
    parameterised with ``FlashGPTNeoXForCausalLM`` as the model class.
    """

    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
        """Build the model through ``FlashCausalLM``.

        Args:
            model_id: Model identifier forwarded to ``FlashCausalLM``.
            revision: Optional model revision forwarded unchanged.
            quantize: Quantization flag forwarded unchanged.
        """
        # Zero-argument super() resolves to the same MRO lookup as
        # super(FlashNeoX, self); the arguments stay positional, in the
        # original order.
        super().__init__(FlashGPTNeoXForCausalLM, model_id, revision, quantize)
|
|
|
|
|
|
|
|
|
|
|
|
class FlashNeoXSharded(FlashNeoX):
    """Tensor-parallel variant of ``FlashNeoX``.

    Each rank builds the model skeleton under ``init_empty_weights`` and then
    loads only its own shard of every ``.safetensors`` weight file.
    """

    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
        """Set up distributed state, tokenizer, config and the sharded model.

        Args:
            model_id: Model identifier passed to the tokenizer, config and
                ``weight_files`` loaders.
            revision: Optional model revision passed to the same loaders.
            quantize: Forwarded to ``load_weights`` (weights are then opened
                on the CPU) and to ``model.post_load_weights``.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.past_pad = None
        self.process_group, self.rank, self.world_size = initialize_torch_distributed()
        self.master = self.rank == 0
        if torch.cuda.is_available():
            # One GPU per rank, indexed by the rank number.
            device = torch.device(f"cuda:{self.rank}")
            dtype = torch.float16
        else:
            raise NotImplementedError("FlashNeoX is only available on GPU")

        tokenizer = AutoTokenizer.from_pretrained(
            model_id, revision=revision, padding_side="left", truncation_side="left"
        )

        config = AutoConfig.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Build the module tree without allocating real weights; load_weights
        # later checks that no parameter is left on the "meta" device.
        with init_empty_weights():
            model = FlashGPTNeoXForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=self.rank,
            world_size=self.world_size,
        )
        self.model = model.eval().to(device)
        torch.distributed.barrier(group=self.process_group)
        # Intentionally bypass FlashCausalLM.__init__ (which, as FlashNeoX
        # shows, takes a model class and model_id and presumably builds the
        # model itself). Instead, call the next initializer in the MRO,
        # because the sharded model is already built above.
        super(FlashCausalLM, self).__init__(
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def _shard(slice_, dim: int, rank: int, world_size: int):
        """Return ``rank``'s contiguous block of ``slice_`` along ``dim``.

        ``dim`` is 0 (rows) or 1 (columns). The block size is
        ``size // world_size``, so any remainder is not assigned to any rank.
        """
        size = slice_.get_shape()[dim]
        block_size = size // world_size
        start = rank * block_size
        stop = (rank + 1) * block_size
        if dim == 0:
            return slice_[start:stop]
        return slice_[:, start:stop]

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: bool,
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        """Load this rank's shard of every tensor in ``filenames`` into ``model``.

        Sharding rules, selected by the type of the owning submodule:

        - ``TensorParallelColumnLinear`` / ``TensorParallelEmbedding``: split
          along dim 0.
        - ``TensorParallelRowLinear``: the weight is split along dim 1. The
          bias is kept whole on rank 0 and zeroed on the other ranks.
        - ``embed_out.weight``: split along dim 0 when
          ``model.gpt_neox.tp_embeddings`` is set.
        - Everything else: loaded whole.

        Tensors are cast to ``dtype``. They are stored in
        ``module._parameters`` when the name matches a model parameter, and
        in ``module._buffers`` otherwise.

        Args:
            model: Model whose submodules receive the tensors.
            filenames: Paths of the ``.safetensors`` files to read.
            quantize: When true, files are opened on ``"cpu"`` instead of
                ``device``; also passed to ``model.post_load_weights``.
            device: Device the files are opened on when not quantizing.
            dtype: Target dtype for every loaded tensor.
            rank: This process's tensor-parallel rank.
            world_size: Number of tensor-parallel ranks.

        Raises:
            ValueError: If a loaded tensor's shape differs from the existing
                parameter of the same name.
            RuntimeError: If any model parameter is still on the ``meta``
                device after all files are loaded.
        """
        parameters = dict(model.named_parameters())
        for file in filenames:
            with safe_open(
                file, framework="pt", device=str(device) if not quantize else "cpu"
            ) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        tensor = FlashNeoXSharded._shard(slice_, 0, rank, world_size)
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            tensor = FlashNeoXSharded._shard(
                                slice_, 1, rank, world_size
                            )
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        tensor = FlashNeoXSharded._shard(slice_, 0, rank, world_size)
                    elif name == "embed_out.weight" and model.gpt_neox.tp_embeddings:
                        tensor = FlashNeoXSharded._shard(slice_, 0, rank, world_size)
                    else:
                        try:
                            tensor = slice_[:]
                        except Exception:
                            # Some tensors cannot be read through the slice
                            # API; fall back to a full read. Narrowed from a
                            # bare except so Ctrl-C / SystemExit still
                            # propagate.
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        # Anything still on "meta" was never provided by the weight files.
        meta = torch.device("meta")
        uninitialized_parameters = [
            n for n, p in model.named_parameters() if p.data.device == meta
        ]
        if uninitialized_parameters:
            raise RuntimeError(
                f"found uninitialized parameters in model: {uninitialized_parameters}"
            )

        model.post_load_weights(quantize)
|