2023-04-03 11:06:42 -06:00
|
|
|
import torch
|
|
|
|
import torch.distributed
|
|
|
|
|
|
|
|
from accelerate import init_empty_weights
|
|
|
|
from opentelemetry import trace
|
2023-04-12 09:18:08 -06:00
|
|
|
from safetensors import safe_open
|
2023-04-03 11:06:42 -06:00
|
|
|
from pathlib import Path
|
2023-04-12 09:18:08 -06:00
|
|
|
from transformers import AutoTokenizer, GPT2Config
|
2023-04-03 11:06:42 -06:00
|
|
|
from typing import Optional, List
|
|
|
|
|
|
|
|
from text_generation_server.models import FlashCausalLM
|
|
|
|
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
|
2023-04-05 11:37:41 -06:00
|
|
|
FlashSantacoderForCausalLM,
|
2023-04-12 09:18:08 -06:00
|
|
|
TensorParallelRowLinear,
|
|
|
|
TensorParallelColumnLinear,
|
|
|
|
TensorParallelEmbedding,
|
2023-04-03 11:06:42 -06:00
|
|
|
)
|
|
|
|
from text_generation_server.utils import (
|
2023-04-12 09:18:08 -06:00
|
|
|
initialize_torch_distributed,
|
2023-04-03 11:06:42 -06:00
|
|
|
weight_files,
|
|
|
|
download_weights,
|
|
|
|
weight_hub_files,
|
|
|
|
LocalEntryNotFoundError,
|
|
|
|
)
|
|
|
|
|
|
|
|
# OpenTelemetry tracer scoped to this module's name.
tracer = trace.get_tracer(instrumenting_module_name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
class FlashSantacoder(FlashCausalLM):
    """Single-GPU flash-attention Santacoder model.

    Weights are read from ``.bin`` checkpoint files and loaded into a
    ``FlashSantacoderForCausalLM`` whose separate ``q_attn`` / ``kv_attn``
    checkpoint tensors are written into one fused ``c_attn`` parameter.
    """

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Load tokenizer, config and weights for ``model_id``.

        Args:
            model_id: Model identifier / path passed to the loaders.
            revision: Optional revision passed to the loaders.
            quantize: Quantization mode forwarded to ``load_weights``; when set,
                weights are first materialized on CPU.
            trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashSantacoder is only available on GPU")
        device = torch.device("cuda")
        dtype = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        # We do not use from_pretrained as we modified the model internal module layout
        try:
            filenames = weight_files(model_id, revision, ".bin")
        except LocalEntryNotFoundError:
            # Local files not found: resolve the hub file list and download it
            hub_files = weight_hub_files(model_id, revision, ".bin")
            filenames = download_weights(hub_files, model_id, revision)

        # Build the module tree without allocating real parameter storage
        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config)

        self.load_weights(
            model,
            filenames,
            quantize,
            device,
            dtype,
            # GPT2-architecture checkpoints need their projection weights transposed
            config.architectures[0].startswith("GPT2"),
        )

        # Deliberately skip FlashCausalLM.__init__ and call the next base class
        super(FlashCausalLM, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
        )

    @staticmethod
    def load_weights(
        model: FlashSantacoderForCausalLM,
        filenames: List[Path],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        transpose: bool,
    ):
        """Copy every tensor from the ``.bin`` checkpoints into ``model`` in place.

        ``q_attn`` / ``kv_attn`` tensors are written into the fused ``c_attn``
        parameter of their layer: query rows first, then the kv rows starting at
        ``head_size * num_heads``. Keys that are not registered parameters of
        their module are stored as buffers.

        Args:
            model: Model whose parameters may still be on the ``meta`` device.
            filenames: Checkpoint files, each loaded with ``torch.load``.
            quantize: When not ``None``, tensors stay on CPU; also forwarded to
                ``model.post_load_weights``.
            device: Target device when ``quantize`` is ``None``.
            dtype: dtype every tensor is cast to.
            transpose: Transpose projection weights (we use nn.Linear instead of
                Conv1D).

        Raises:
            ValueError: If a non-attention parameter's shape does not match the
                checkpoint tensor.
        """
        head_size = model.transformer.head_size
        num_heads = model.transformer.num_heads
        # Fused c_attn layout: num_heads query heads followed by 2 kv heads
        q_rows = head_size * num_heads
        fused_rows = head_size * (num_heads + 2)
        # Weights that must be transposed when `transpose` is set
        conv1d_weights = (
            "c_fc.weight",
            "c_proj.weight",
            "q_attn.weight",
            "kv_attn.weight",
            "c_attn.weight",
        )
        target_device = device if quantize is None else "cpu"

        for filename in filenames:
            # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
            state_dict = torch.load(filename, map_location="cpu")
            for key, value in state_dict.items():
                value = value.to(target_device).to(dtype)

                layer_name = ".".join(key.split(".")[:4])

                # Fused qkv: q_attn / kv_attn both map onto the layer's c_attn
                if "q_attn.weight" in key or "kv_attn.weight" in key:
                    final_key = layer_name + ".c_attn.weight"
                elif "q_attn.bias" in key or "kv_attn.bias" in key:
                    final_key = layer_name + ".c_attn.bias"
                else:
                    final_key = key

                module_name, param_name = final_key.rsplit(".", 1)
                module = model.get_submodule(module_name)
                current_parameter_tensor = module._parameters.get(param_name)

                if current_parameter_tensor is not None:
                    if transpose and any(name in key for name in conv1d_weights):
                        # Transpose as we use nn.Linear instead of Conv1D
                        value = value.T

                    if current_parameter_tensor.device == torch.device("meta"):
                        # Allocate the fused qkv parameter the first time it is written
                        if "c_attn.weight" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (fused_rows, value.shape[1])
                            )
                        elif "c_attn.bias" in final_key:
                            module._parameters[param_name] = value.new_empty(
                                (fused_rows,)
                            )

                    # Copy to correct slice of the fused parameter
                    if "q_attn.weight" in key or "q_attn.bias" in key:
                        module._parameters[param_name][: value.shape[0]] = value
                    elif "kv_attn.weight" in key or "kv_attn.bias" in key:
                        module._parameters[param_name][q_rows:] = value
                    else:
                        if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- Current {current_parameter_tensor.shape} and got {value.shape}"
                            )
                        module._parameters[param_name] = value
                else:
                    module._buffers[param_name] = value

                del value

            # Release this file's CPU tensors before unpickling the next file
            del state_dict

        torch.cuda.empty_cache()
        model.post_load_weights(quantize)

    def decode(self, generated_ids: List[int]) -> str:
        """Decode token ids to text, keeping special tokens and raw spacing."""
        # Do not skip special tokens as they are used for custom parsing rules of the generated text
        return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
        )
|
2023-04-12 09:18:08 -06:00
|
|
|
|
|
|
|
|
|
|
|
class FlashSantacoderSharded(FlashSantacoder):
    """Tensor-parallel ``FlashSantacoder`` that loads ``.safetensors`` shards per rank."""

    def __init__(
        self,
        model_id: str,
        revision: Optional[str] = None,
        quantize: Optional[str] = None,
        trust_remote_code: bool = False,
    ):
        """Initialize the process group, then load this rank's slice of the model.

        Args:
            model_id: Model identifier / path passed to the loaders.
            revision: Optional revision passed to the loaders.
            quantize: Quantization mode forwarded to ``load_weights``.
            trust_remote_code: Forwarded to ``AutoTokenizer.from_pretrained``.

        Raises:
            NotImplementedError: If CUDA is not available.
        """
        self.process_group, rank, world_size = initialize_torch_distributed()
        if not torch.cuda.is_available():
            raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
        device = torch.device(f"cuda:{rank}")
        dtype = torch.float16

        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            revision=revision,
            padding_side="left",
            truncation_side="left",
            trust_remote_code=trust_remote_code,
        )

        config = GPT2Config.from_pretrained(
            model_id,
            revision=revision,
        )

        torch.distributed.barrier(group=self.process_group)
        filenames = weight_files(model_id, revision=revision, extension=".safetensors")

        # Build the module tree without allocating real parameter storage
        with init_empty_weights():
            model = FlashSantacoderForCausalLM(config, self.process_group)

        torch.distributed.barrier(group=self.process_group)
        self.load_weights(
            model,
            filenames,
            quantize=quantize,
            device=device,
            dtype=dtype,
            rank=rank,
            world_size=world_size,
            transpose=config.architectures[0].startswith("GPT2"),
        )
        torch.distributed.barrier(group=self.process_group)
        # Deliberately skip FlashCausalLM.__init__ and call the next base class
        super(FlashCausalLM, self).__init__(
            model=model.to(device),
            tokenizer=tokenizer,
            requires_padding=False,
            dtype=dtype,
            device=device,
            rank=rank,
            world_size=world_size,
        )

    @staticmethod
    def load_weights(
        model,
        filenames: List[str],
        quantize: Optional[str],
        device: torch.device,
        dtype: torch.dtype,
        rank: int,
        world_size: int,
        transpose: bool,
    ):
        """Load this rank's shard of every safetensors tensor into ``model``.

        - Column-parallel modules are sliced along dim 0, or along dim 1 for
          transposed weights.
        - Row-parallel weights are sliced along dim 1, or along dim 0 when
          transposed. Their biases are kept only on rank 0 and zeroed elsewhere,
          so the bias is added once.
        - Embeddings, and ``lm_head.weight`` when ``tp_embeddings`` is set, are
          sliced along dim 0.
        - In the fused ``c_attn`` parameter, the query part is sharded per rank
          and the kv part is replicated on every rank.
        - After loading, ``lm_head.weight`` is tied to ``transformer.wte.weight``.

        Args:
            model: Model whose parameters may still be on the ``meta`` device.
            filenames: ``.safetensors`` files to read.
            quantize: When not ``None``, tensors are read on CPU; also forwarded
                to ``model.post_load_weights``.
            device: Device tensors are read onto when ``quantize`` is ``None``.
            dtype: dtype every tensor is cast to.
            rank: This process's tensor-parallel rank.
            world_size: Number of tensor-parallel ranks.
            transpose: Transpose projection weights (we use nn.Linear instead of
                Conv1D).

        Raises:
            ValueError: If a non-attention parameter's shape does not match the
                loaded tensor.
        """

        def shard(size):
            # [start, stop) bounds of this rank's equal-sized block of `size`
            block_size = size // world_size
            return rank * block_size, (rank + 1) * block_size

        head_size = model.transformer.head_size
        # Fused c_attn layout: num_heads query heads followed by 2 kv heads
        q_rows = head_size * model.transformer.num_heads
        fused_rows = head_size * (model.transformer.num_heads + 2)
        # Weights that must be transposed when `transpose` is set
        conv1d_weights = (
            "c_fc.weight",
            "c_proj.weight",
            "q_attn.weight",
            "kv_attn.weight",
            "c_attn.weight",
        )
        read_device = str(device) if quantize is None else "cpu"

        for file in filenames:
            with safe_open(file, framework="pt", device=read_device) as f:
                for key in f.keys():
                    slice_ = f.get_slice(key)

                    layer_name = ".".join(key.split(".")[:4])

                    # Fused qkv: q_attn / kv_attn both map onto the layer's c_attn
                    if "q_attn.weight" in key or "kv_attn.weight" in key:
                        final_key = layer_name + ".c_attn.weight"
                    elif "q_attn.bias" in key or "kv_attn.bias" in key:
                        final_key = layer_name + ".c_attn.bias"
                    else:
                        final_key = key

                    module_name, param_name = final_key.rsplit(".", 1)
                    module = model.get_submodule(module_name)

                    if isinstance(module, TensorParallelColumnLinear):
                        dim = 1 if transpose and "weight" in param_name else 0
                        start, stop = shard(slice_.get_shape()[dim])
                        tensor = (
                            slice_[start:stop] if dim == 0 else slice_[:, start:stop]
                        )
                    elif isinstance(module, TensorParallelRowLinear):
                        if param_name == "weight":
                            dim = 0 if transpose else 1
                            start, stop = shard(slice_.get_shape()[dim])
                            tensor = (
                                slice_[start:stop]
                                if dim == 0
                                else slice_[:, start:stop]
                            )
                        else:
                            tensor = slice_[:]
                            # XXX: Hack for Rowlinear to add the bias only once.
                            if rank != 0:
                                tensor = torch.zeros_like(tensor)
                    elif isinstance(module, TensorParallelEmbedding):
                        start, stop = shard(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    elif key == "lm_head.weight" and model.transformer.tp_embeddings:
                        start, stop = shard(slice_.get_shape()[0])
                        tensor = slice_[start:stop]
                    else:
                        try:
                            tensor = slice_[:]
                        except (KeyboardInterrupt, SystemExit):
                            # Never swallow interrupts during a long weight load
                            raise
                        except BaseException:
                            # Best-effort fallback to a full read, kept from the original
                            # bare `except:`. NOTE(review): presumably for tensors the
                            # slice API cannot handle; confirm which error is raised.
                            tensor = f.get_tensor(key)

                    tensor = tensor.contiguous().to(dtype)

                    current_parameter_tensor = module._parameters.get(param_name)

                    if current_parameter_tensor is not None:
                        if transpose and any(name in key for name in conv1d_weights):
                            # Transpose as we use nn.Linear instead of Conv1D
                            tensor = tensor.T

                        if current_parameter_tensor.device == torch.device("meta"):
                            # Allocate the fused qkv parameter the first time it is written
                            if "c_attn.weight" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (fused_rows, tensor.shape[1])
                                )
                            elif "c_attn.bias" in final_key:
                                module._parameters[param_name] = tensor.new_empty(
                                    (fused_rows,)
                                )

                        # Copy to correct slice of the fused parameter
                        if "q_attn" in key:
                            start, stop = shard(tensor.shape[0])
                            tensor = tensor[start:stop]
                            module._parameters[param_name][: tensor.shape[0]] = tensor
                        elif "kv_attn.weight" in key or "kv_attn.bias" in key:
                            module._parameters[param_name][q_rows:] = tensor
                        elif "c_attn" in key:
                            # Slice q_tensor by shard
                            q_tensor = tensor[: -2 * head_size]
                            start, stop = shard(q_tensor.shape[0])
                            q_tensor = q_tensor[start:stop]
                            module._parameters[param_name][
                                : q_tensor.shape[0]
                            ] = q_tensor

                            # Kv tensor is copied for every shard
                            kv_tensor = tensor[-2 * head_size :]
                            module._parameters[param_name][
                                q_tensor.shape[0] :
                            ] = kv_tensor
                        else:
                            if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {key} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                                )

                            module._parameters[param_name] = tensor
                    else:
                        module._buffers[param_name] = tensor

        # Tie the output projection to the input embedding
        model.lm_head.weight = torch.nn.Parameter(model.transformer.wte.weight)
        torch.cuda.empty_cache()
        model.post_load_weights(quantize)
|