245 lines
8.3 KiB
Python
245 lines
8.3 KiB
Python
import torch
|
|
import torch.distributed
|
|
|
|
from pathlib import Path
|
|
from accelerate import init_empty_weights
|
|
from opentelemetry import trace
|
|
from safetensors import safe_open
|
|
from transformers import AutoTokenizer, AutoConfig
|
|
from typing import Optional, List
|
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
from text_generation_server.models.custom_modeling.flash_rw_modeling import (
|
|
RWConfig,
|
|
FlashRWForCausalLM,
|
|
TensorParallelEmbedding,
|
|
TensorParallelRowLinear,
|
|
TensorParallelColumnLinear,
|
|
)
|
|
from text_generation_server.utils import (
|
|
initialize_torch_distributed,
|
|
weight_files,
|
|
download_weights,
|
|
weight_hub_files,
|
|
LocalEntryNotFoundError,
|
|
)
|
|
|
|
# Module-level OpenTelemetry tracer, named after this module.
tracer = trace.get_tracer(__name__)
|
|
|
|
|
|
class FlashRW(FlashCausalLM):
    """Single-device RW causal language model backed by ``FlashRWForCausalLM``.

    The constructor loads the tokenizer and ``RWConfig`` for ``model_id``,
    instantiates the model under ``init_empty_weights`` and then fills it
    from ``.bin`` checkpoint files through :meth:`load_weights`.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights, then initialise the base model.

        Args:
            model_id: Identifier handed to the tokenizer/config loaders and
                to the weight-file helpers.
            revision: Optional revision forwarded to the same helpers.
            quantize: Optional quantization mode. When set, tensors are kept
                on CPU while loading; the value is also passed to
                ``model.post_load_weights``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError("RW is only available on GPU")
        device = torch.device("cuda")
        dtype = torch.bfloat16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        # Forward trust_remote_code here as well, matching the tokenizer call
        # above and the config loading done by FlashRWSharded.
        config = RWConfig.from_pretrained(
            model_id,
            revision=revision,
            trust_remote_code=trust_remote_code,
        )

        # We do not use from_pretrained as it is too slow
        try:
            filenames = weight_files(model_id, revision, ".bin")
        # Local files not found
        except LocalEntryNotFoundError:
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        # Build the module structure first; the actual tensors are attached
        # by load_weights below.
        with init_empty_weights():
            model = FlashRWForCausalLM(config)

        self.load_weights(
            model,
            filenames,
            quantize,
            device,
            dtype,
        )

        # NOTE: super(FlashCausalLM, self) starts the MRO lookup *after*
        # FlashCausalLM, so FlashCausalLM.__init__ itself is not executed.
        super(FlashCausalLM, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model: FlashRWForCausalLM,
        filenames: List[Path],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
    ):
        """Copy every tensor from the checkpoint files into ``model``.

        Each checkpoint key is split on its last ``"."`` into a submodule
        path and an attribute name. If the attribute is a registered
        parameter of that submodule its shape is validated and it is
        replaced; otherwise the tensor is stored as a buffer.

        Args:
            model: Model whose parameters/buffers are overwritten in place.
            filenames: Checkpoint files readable by ``torch.load``.
            quantize: When ``None`` tensors are moved to ``device``; otherwise
                they stay on CPU. Also passed to ``model.post_load_weights``.
            device: Target device for unquantized tensors.
            dtype: dtype every loaded tensor is cast to.

        Raises:
            ValueError: If a checkpoint tensor's shape differs from the shape
                of the parameter it replaces.
        """
        # Loop-invariant: where tensors live while being loaded.
        target_device = device if quantize is None else "cpu"

        for filename in filenames:
            # SECURITY: torch.load unpickles the file contents; only load
            # checkpoints that come from a trusted source.
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(target_device).to(dtype)

                module_name, param_name = key.rsplit(".", 1)
                module = model.get_submodule(module_name)

                # Keep the try body limited to the lookup so that only a
                # missing parameter name selects the buffer path.
                try:
                    current_parameter_tensor = module._parameters[param_name]
                except KeyError:
                    module._buffers[param_name] = value
                else:
                    if current_parameter_tensor.shape != value.shape:
                        raise ValueError(
                            f"Name {key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                        )
                    module._parameters[param_name] = value

                # Drop the local reference before moving on to the next tensor.
                del value

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)
|
|
|
|
|
|
class FlashRWSharded(FlashRW):
    """Tensor-parallel variant of :class:`FlashRW`.

    Each process obtains its rank and world size from
    ``initialize_torch_distributed`` and loads only its own shard of the
    tensor-parallel layers from ``.safetensors`` files.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Set up the process group, load this rank's weights and init the base model.

        Args:
            model_id: Identifier handed to the tokenizer/config loaders and
                to ``weight_files``.
            revision: Optional revision forwarded to the same helpers.
            quantize: Optional quantization mode. When set, tensors are read
                on CPU; the value is also passed to
                ``model.post_load_weights``.
            trust_remote_code: Forwarded to the tokenizer and config loaders.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashRW is only available on GPU")
        # One CUDA device per rank.
        device = torch.device(f"cuda:{rank}")
        dtype = torch.bfloat16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = RWConfig.from_pretrained(
            model_id, revision=revision, trust_remote_code=trust_remote_code
        )

        # Barriers keep all ranks of the process group in step around the
        # file lookup, model construction and weight loading phases.
        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        with init_empty_weights():
            model = FlashRWForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
        )
        torch.distributed.barrier(group=self.process_group)

        # NOTE: super(FlashCausalLM, self) starts the MRO lookup *after*
        # FlashCausalLM, so FlashCausalLM.__init__ itself is not executed.
        super(FlashCausalLM, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
    ):
        """Load this rank's shard of every tensor in ``filenames`` into ``model``.

        How a tensor is cut depends on the submodule that owns it:

        * ``TensorParallelColumnLinear`` and ``TensorParallelEmbedding``:
          block ``rank`` along dimension 0.
        * ``TensorParallelRowLinear``: ``weight`` is cut along dimension 1;
          any other attribute (the bias) is loaded whole and zeroed on every
          rank except 0.
        * ``lm_head.weight`` when ``model.transformer.tp_embeddings`` is
          truthy: block ``rank`` along dimension 0.
        * Anything else is loaded whole.

        Block size is ``size // world_size``; any remainder is not loaded.

        Args:
            model: Model whose parameters/buffers are overwritten in place.
            filenames: ``.safetensors`` files to read.
            quantize: When ``None`` tensors are opened on ``device``;
                otherwise on CPU. Also passed to ``model.post_load_weights``.
            device: Target device for unquantized tensors.
            dtype: dtype every loaded tensor is cast to.
            rank: Index of the current process.
            world_size: Total number of processes sharing the model.

        Raises:
            ValueError: If a loaded shard's shape differs from the shape of
                the parameter it replaces.
        """

        def shard_bounds(size):
            # (start, stop) of this rank's equal-sized block out of `size`.
            block_size = size // world_size
            return rank * block_size, (rank + 1) * block_size

        parameters = dict(model.named_parameters())
        # Loop-invariant: device string handed to safe_open.
        load_device = str(device) if quantize is None else "cpu"

        for file in filenames:
            with safe_open(file, framework="pt", device=load_device) as f:
                for name in f.keys():
                    module_name, param_name = name.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    current_parameter_tensor = parameters.get(name, None)

                    slice_ = f.get_slice(name)

                    if isinstance(module, TensorParallelColumnLinear):
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            start, stop = shard_bounds(slice_.get_shape()[1])
                            tensor = slice_[:, start:stop]
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    elif name == "lm_head.weight" and model.transformer.tp_embeddings:
                        start, stop = shard_bounds(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    else:
                        # Fall back to reading the full tensor if slicing
                        # fails. Catch Exception rather than everything so
                        # KeyboardInterrupt/SystemExit still propagate.
                        try:
                            tensor = slice_[:]
                        except Exception:
                            tensor = f.get_tensor(name)

                    if (
                        current_parameter_tensor is not None
                        and current_parameter_tensor.shape != tensor.shape
                    ):
                        raise ValueError(
                            f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                        )

                    tensor = tensor.contiguous().to(dtype)

                    # Names known to named_parameters() replace parameters;
                    # everything else is stored as a buffer.
                    if current_parameter_tensor is not None:
                        module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        model.post_load_weights(quantize)
|