hf_text-generation-inference/server/text_generation_server/models/flash_santacoder.py

import torch
import torch.distributed
from accelerate import init_empty_weights
from opentelemetry import trace
from safetensors import safe_open
from pathlib import Path
from transformers import AutoTokenizer, GPT2Config
from typing import Optional, List
from text_generation_server.models import FlashCausalLM
from text_generation_server.models.custom_modeling.flash_santacoder_modeling import (
FlashSantacoderForCausalLM,
TensorParallelRowLinear,
TensorParallelColumnLinear,
TensorParallelEmbedding,
)
from text_generation_server.utils import (
initialize_torch_distributed,
weight_files,
download_weights,
weight_hub_files,
LocalEntryNotFoundError,
)
tracer = trace.get_tracer(__name__)
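# Non-sharded FlashSantacoder: runs on a single GPU and remaps the .bin
# checkpoint by hand (see load_weights) because the flash-attention module
# layout differs from the upstream checkpoint.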
class FlashSantacoder(FlashCausalLM):
    def __init__(
        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
    ):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashSantacoder is only available on GPU")
if quantize:
raise NotImplementedError("FlashSantacoder does not support quantization")
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
config = GPT2Config.from_pretrained(
model_id,
revision=revision,
)
        # We do not use from_pretrained as we modified the model's internal module layout
try:
filenames = weight_files(model_id, revision, ".bin")
# Local files not found
except LocalEntryNotFoundError:
hub_files = weight_hub_files(model_id, revision, ".bin")
filenames = download_weights(hub_files, model_id, revision)
with init_empty_weights():
model = FlashSantacoderForCausalLM(config)
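        # GPT2-architecture checkpoints store Conv1D weights, so load_weights
        # must transpose them into nn.Linear layout.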
self.load_weights(
model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
)
self.model = model.eval()
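        # Bypass FlashCausalLM.__init__ (the model is already built above) and
        # initialize the shared base class directly.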
super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=1
)
@staticmethod
def load_weights(
model: FlashSantacoderForCausalLM,
filenames: List[Path],
device: torch.device,
dtype: torch.dtype,
transpose: bool,
):
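        # Stream each .bin shard, move tensors to the target device/dtype, and
        # remap checkpoint names: separate q_attn / kv_attn tensors are copied
        # into the fused c_attn parameter, transposing Conv1D weights when
        # `transpose` is set.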
for filename in filenames:
state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items():
value = value.to(device).to(dtype)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
                        # Transpose as we use nn.Linear instead of Conv1D
value = value.T
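                    # Parameters still on the meta device have not been allocated
                    # yet: create the full fused qkv tensor, sized for num_heads
                    # query heads plus one shared key and one shared value head
                    # (multi-query attention), before copying slices into it.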
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
value.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = value.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
                    # Copy into the correct slice of the fused qkv parameter (q first, then kv)
if "q_attn.weight" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "q_attn.bias" in key:
module._parameters[param_name][: value.shape[0]] = value
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads :
] = value
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size * model.transformer.num_heads :
] = value
else:
if current_parameter_tensor.shape != value.shape:
                            raise ValueError(
                                f"Name {final_key} -- current shape {current_parameter_tensor.shape} does not match checkpoint shape {value.shape}"
                            )
module._parameters[param_name] = value
else:
module._buffers[param_name] = value
del value
torch.cuda.empty_cache()
model.post_load_weights()
def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text
return self.tokenizer.decode(
            generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)
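# Tensor-parallel FlashSantacoder: each rank loads only its slice of the
# .safetensors weights and keeps its shard of the column/row-parallel layers.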
class FlashSantacoderSharded(FlashSantacoder):
def __init__(
self, model_id: str, revision: Optional[str] = None, quantize: bool = False
):
self.process_group, self.rank, self.world_size = initialize_torch_distributed()
self.master = self.rank == 0
if torch.cuda.is_available():
device = torch.device(f"cuda:{self.rank}")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
if quantize:
raise NotImplementedError(
"FlashSantacoderSharded does not support quantization"
)
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
config = GPT2Config.from_pretrained(
model_id,
revision=revision,
)
torch.distributed.barrier(group=self.process_group)
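        # The sharded path expects .safetensors weights so that each rank can
        # lazily read just its slice of every tensor.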
filenames = weight_files(model_id, revision=revision, extension=".safetensors")
with init_empty_weights():
model = FlashSantacoderForCausalLM(config, self.process_group)
torch.distributed.barrier(group=self.process_group)
self.load_weights(
model,
filenames,
device=device,
dtype=dtype,
rank=self.rank,
world_size=self.world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
tokenizer=tokenizer,
device=device,
)
@staticmethod
def load_weights(
model,
filenames: List[str],
device: torch.device,
dtype: torch.dtype,
rank: int,
world_size: int,
transpose: bool,
):
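        # For every tensor in the .safetensors files, load only the slice owned
        # by this rank: column-parallel weights are split along the output
        # dimension, row-parallel weights along the input dimension, embeddings
        # and the tied lm_head along dim 0; everything else is replicated.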
for file in filenames:
with safe_open(file, framework="pt", device=str(device)) as f:
for key in f.keys():
slice_ = f.get_slice(key)
layer_name = ".".join(key.split(".")[:4])
# Fused qkv
if "q_attn.weight" in key or "kv_attn.weight" in key:
final_key = layer_name + ".c_attn.weight"
elif "q_attn.bias" in key or "kv_attn.bias" in key:
final_key = layer_name + ".c_attn.bias"
else:
final_key = key
module_name, param_name = final_key.rsplit(".", 1)
module = model.get_submodule(module_name)
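                    # Slice the checkpoint tensor according to how the target
                    # module is parallelized.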
if isinstance(module, TensorParallelColumnLinear):
dim = 1 if transpose and "weight" in param_name else 0
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop] if dim == 0 else slice_[:, start:stop]
)
elif isinstance(module, TensorParallelRowLinear):
if param_name == "weight":
dim = 0 if transpose else 1
size = slice_.get_shape()[dim]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = (
slice_[start:stop]
if dim == 0
else slice_[:, start:stop]
)
else:
tensor = slice_[:]
                            # XXX: Hack for RowLinear: keep the bias on rank 0 only
                            # so it is added exactly once after the output all-reduce.
if rank != 0:
tensor = torch.zeros_like(tensor)
elif isinstance(module, TensorParallelEmbedding):
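                        # Embedding weights are split along their first dimension,
                        # one contiguous block of rows per rank.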
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
elif key == "lm_head.weight" and model.transformer.tp_embeddings:
size = slice_.get_shape()[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = slice_[start:stop]
else:
try:
tensor = slice_[:]
                        except Exception:
tensor = f.get_tensor(key)
tensor = tensor.contiguous().to(dtype)
try:
current_parameter_tensor = module._parameters[param_name]
except KeyError:
current_parameter_tensor = None
if current_parameter_tensor is not None:
if transpose and (
"c_fc.weight" in key
or "c_proj.weight" in key
or "q_attn.weight" in key
or "kv_attn.weight" in key
or "c_attn.weight" in key
):
                            # Transpose as we use nn.Linear instead of Conv1D
tensor = tensor.T
if current_parameter_tensor.device == torch.device("meta"):
# Init qkv
if "c_attn.weight" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2),
tensor.shape[1],
)
)
elif "c_attn.bias" in final_key:
module._parameters[param_name] = tensor.new_empty(
(
model.transformer.head_size
* (model.transformer.num_heads + 2)
)
)
                        # Copy into the correct slice of the fused qkv parameter
if "q_attn" in key:
size = tensor.shape[0]
block_size = size // world_size
start = rank * block_size
stop = (rank + 1) * block_size
tensor = tensor[start:stop]
module._parameters[param_name][: tensor.shape[0]] = tensor
elif "kv_attn.weight" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "kv_attn.bias" in key:
module._parameters[param_name][
model.transformer.head_size
* model.transformer.num_heads :
] = tensor
elif "c_attn" in key:
# Slice q_tensor by shard
q_tensor = tensor[: -2 * model.transformer.head_size]
block_size = q_tensor.shape[0] // world_size
start = rank * block_size
stop = (rank + 1) * block_size
q_tensor = q_tensor[start:stop]
module._parameters[param_name][
: q_tensor.shape[0]
] = q_tensor
# Kv tensor is copied for every shard
kv_tensor = tensor[-2 * model.transformer.head_size :]
module._parameters[param_name][
q_tensor.shape[0] :
] = kv_tensor
else:
if current_parameter_tensor.shape != tensor.shape:
                                raise ValueError(
                                    f"Name {key} -- current shape {current_parameter_tensor.shape} does not match checkpoint shape {tensor.shape}"
                                )
module._parameters[param_name] = tensor
else:
module._buffers[param_name] = tensor
torch.cuda.empty_cache()
model.post_load_weights()