parent 2475aede61
commit e14ae3b5e9
@@ -26,7 +26,9 @@ try:

     FLASH_ATTENTION = torch.cuda.is_available()
 except ImportError:
-    logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
+    logger.opt(exception=True).warning(
+        "Could not import Flash Attention enabled models"
+    )
     FLASH_ATTENTION = False

 __all__ = [
@@ -88,10 +90,10 @@ def get_model(
                 raise NotImplementedError(
                     FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
                 )
-            return FlashSantacoderSharded(model_id, revision=revision)
+            return FlashSantacoderSharded(model_id, revision, quantize=quantize)
         else:
             santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
-            return santacoder_cls(model_id, revision, quantize)
+            return santacoder_cls(model_id, revision, quantize=quantize)

     config = AutoConfig.from_pretrained(model_id, revision=revision)
     model_type = config.model_type
@@ -33,6 +33,12 @@ import dropout_layer_norm

 from flash_attn.layers.rotary import RotaryEmbedding

+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+

 class LlamaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
@@ -94,14 +100,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None

-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)


 class TensorParallelColumnLinear(FastLinear):
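The `prepare_weights(quantize=True)` path above follows the standard bitsandbytes int8 recipe: build a `Linear8bitLt` with `has_fp16_weights=False`, copy the fp16 weight data into it, and let the actual quantization happen when the module is moved to the GPU. Below is a minimal standalone sketch of that pattern, assuming bitsandbytes is installed and a CUDA device is present; the layer sizes are illustrative and not taken from this diff.

    # Standalone sketch of the prepare_weights(quantize=True) pattern.
    # Assumes bitsandbytes is installed and a CUDA GPU is available.
    import torch
    import torch.nn as nn
    from bitsandbytes.nn import Linear8bitLt

    fp16_linear = nn.Linear(4096, 4096, bias=False).half()

    int8_linear = Linear8bitLt(
        fp16_linear.in_features,
        fp16_linear.out_features,
        bias=False,
        has_fp16_weights=False,  # keep only the int8 inference copy of the weights
        threshold=6.0,           # outlier threshold for the mixed int8/fp16 matmul
    )
    int8_linear.weight.data = fp16_linear.weight.data

    # The int8 quantization itself happens on the host-to-device transfer.
    int8_linear = int8_linear.to("cuda")

    x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
    y = int8_linear(x)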
@@ -502,15 +538,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads

-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit: bool = False):
         if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.transpose_weight()
-            layer.self_attn.o_proj.transpose_weight()
-            layer.mlp.gate_up_proj.transpose_weight()
-            layer.mlp.down_proj.transpose_weight()
+            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
+            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
+            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
+            layer.mlp.down_proj.prepare_weights(load_in_8bit)

     def forward(
         self,
@@ -592,9 +628,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self):
-        self.model.post_load_weights()
-        self.lm_head.transpose_weight()
+    def post_load_weights(self, load_in_8bit: bool = False):
+        self.model.post_load_weights(load_in_8bit)
+        self.lm_head.prepare_weights()

     def forward(
         self,
@@ -35,6 +35,12 @@ import dropout_layer_norm

 from flash_attn.layers.rotary import RotaryEmbedding

+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+

 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
@@ -82,14 +88,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None

-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)


 class TensorParallelColumnLinear(FastLinear):
@@ -552,23 +588,27 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads

-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit=False):
         if isinstance(self.embed_in, TensorParallelEmbedding):
             self.embed_in.add_null_idx()
         for layer in self.layers:
             layer: FlashNeoXLayer
             layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.transpose_weight()
-            layer.attention.dense.transpose_weight()
-            layer.mlp.dense_h_to_4h.transpose_weight()
-            layer.mlp.dense_4h_to_h.transpose_weight()
+            layer.attention.query_key_value.prepare_weights(load_in_8bit)
+            layer.attention.dense.prepare_weights(load_in_8bit)
+            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
+            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+        # to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
         model = super(FlashGPTNeoXModel, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
         )
-        model.post_load_weights()
+
+        model.post_load_weights(load_in_8bit)
         return model

     def forward(
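The `load_in_8bit` kwarg is popped before calling into `transformers`' `from_pretrained` so the stock loader never tries to quantize the custom layers itself; quantization is applied afterwards by `post_load_weights`. A condensed, hypothetical illustration of that ordering (the function name is made up for the example):

    # Hypothetical condensed view of the "pop, load, then quantize ourselves" ordering.
    def load_quantized(cls, model_id, **kwargs):
        load_in_8bit = kwargs.pop("load_in_8bit", False)   # keep transformers out of it
        model = cls.from_pretrained(model_id, load_in_8bit=False, **kwargs)
        model.post_load_weights(load_in_8bit)              # swap FastLinear weights for Linear8bitLt
        return model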
@@ -653,16 +693,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )

-    def post_load_weights(self):
-        self.gpt_neox.post_load_weights()
-        self.embed_out.transpose_weight()
+    def post_load_weights(self, load_in_8bit=False):
+        self.gpt_neox.post_load_weights(load_in_8bit)
+        self.embed_out.prepare_weights()

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        # Pop here as we will replace the layer in our own logic and don't want from_pretrained
+        # to do it for us
+        load_in_8bit = kwargs.pop("load_in_8bit", False)
         model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
-            pretrained_model_name_or_path, *model_args, **kwargs
+            pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
         )
-        model.post_load_weights()
+        model.post_load_weights(load_in_8bit)
         return model

     def forward(
@@ -10,6 +10,12 @@ from transformers.activations import ACT2FN
 import flash_attn_cuda
 import dropout_layer_norm

+HAS_BITS_AND_BYTES = True
+try:
+    from bitsandbytes.nn import Linear8bitLt
+except ImportError as e:
+    HAS_BITS_AND_BYTES = False
+

 class FastLayerNorm(nn.LayerNorm):
     def forward(self, hidden_states, residual=None):
@@ -57,14 +63,44 @@ class FastLinear(nn.Linear):
         dtype=None,
     ) -> None:
         super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
+        self.quantized = False
+        self.bnb_linear = None

-    def transpose_weight(self):
-        self.weight = nn.Parameter(self.weight.T)
+    def prepare_weights(self, quantize: bool = False):
+        if quantize:
+            if not HAS_BITS_AND_BYTES:
+                raise ImportError(
+                    "bitsandbytes is not available on your machine either because it is not installed "
+                    "or you don't have a GPU.\n"
+                    "You can install it with `pip install bitsandbytes`."
+                )
+
+            self.quantized = True
+            self.bnb_linear = Linear8bitLt(
+                self.in_features,
+                self.out_features,
+                has_fp16_weights=False,
+                threshold=6.0,
+                bias=False,
+            )
+            # Copy data to bnb_linear
+            self.bnb_linear.weight.data = self.weight.data
+            if self.bias is not None:
+                self.bnb_linear.bias = nn.Parameter(self.bias)
+
+            # Delete reference to data
+            self.weight = None
+            self.bias = None
+        else:
+            self.weight = nn.Parameter(self.weight.T)

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        if self.bias is not None:
-            return torch.addmm(self.bias, input, self.weight)
-        return torch.matmul(input, self.weight)
+        if self.quantized:
+            return self.bnb_linear(input)
+        else:
+            if self.bias is not None:
+                return torch.addmm(self.bias, input, self.weight)
+            return torch.matmul(input, self.weight)


 class TensorParallelColumnLinear(FastLinear):
@@ -431,16 +467,16 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads

-    def post_load_weights(self):
+    def post_load_weights(self, load_in_8bit: bool = False):
         if self.tp_embeddings:
             self.wte.add_null_idx()
             self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.c_attn.transpose_weight()
-            layer.attn.c_proj.transpose_weight()
-            layer.mlp.c_fc.transpose_weight()
-            layer.mlp.c_proj.transpose_weight()
+            layer.attn.c_attn.prepare_weights(load_in_8bit)
+            layer.attn.c_proj.prepare_weights(load_in_8bit)
+            layer.mlp.c_fc.prepare_weights(load_in_8bit)
+            layer.mlp.c_proj.prepare_weights(load_in_8bit)

     def forward(
         self,
@@ -508,9 +544,9 @@ class FlashSantacoderForCausalLM(nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self):
-        self.transformer.post_load_weights()
-        self.lm_head.transpose_weight()
+    def post_load_weights(self, load_in_8bit: bool = False):
+        self.transformer.post_load_weights(load_in_8bit)
+        self.lm_head.prepare_weights()

     def forward(
         self,
@@ -221,9 +221,6 @@ class FlashCausalLM(Model):
         else:
             raise NotImplementedError("FlashCausalLM is only available on GPU")

-        if quantize:
-            raise NotImplementedError("FlashCausalLM does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -232,9 +229,10 @@ class FlashCausalLM(Model):
                 model_id,
                 revision=revision,
                 torch_dtype=dtype,
+                load_in_8bit=quantize,
             )
             .eval()
-            .cuda()
+            .to(device)
         )

         super(FlashCausalLM, self).__init__(
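For the generic `FlashCausalLM` path, quantization is delegated to `transformers` itself via the `load_in_8bit` flag on `from_pretrained`. Outside of this custom loader, the stock usage of that flag typically looks like the sketch below; the model id and `device_map` are illustrative and not taken from this diff, and `accelerate` plus `bitsandbytes` are assumed to be installed.

    # Illustrative stock transformers int8 loading; not the exact code path in this diff.
    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m",   # placeholder model id
        torch_dtype=torch.float16,
        load_in_8bit=True,         # requires accelerate + bitsandbytes
        device_map="auto",
    )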
@@ -35,9 +35,6 @@ class FlashLlama(FlashCausalLM):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        if quantize:
-            raise NotImplementedError("FlashLlama does not support quantization")
-
         tokenizer = LlamaTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -61,8 +58,8 @@ class FlashLlama(FlashCausalLM):
         with init_empty_weights():
             model = FlashLlamaForCausalLM(config)

-        self.load_weights(model, filenames, device, dtype)
-        self.model = model.eval()
+        self.load_weights(model, filenames, quantize, device, dtype)
+        self.model = model.eval().to(device)

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -73,13 +70,14 @@ class FlashLlama(FlashCausalLM):
     def load_weights(
         model,
         filenames: List[Path],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device).to(dtype)
+                value = value.to(device if not quantize else "cpu").to(dtype)

                 layer_name = ".".join(key.split(".")[:4])

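Keeping the raw weights on `"cpu"` while `quantize` is set pairs with the later `model.eval().to(device)`: with `has_fp16_weights=False`, bitsandbytes quantizes the fp16 weights at the moment they are transferred to the GPU, so a full-precision copy never has to sit in GPU memory. A small check of that behaviour, as a sketch assuming bitsandbytes is installed and a GPU is available:

    # Sketch: the int8 conversion happens on the CPU -> GPU transfer.
    import torch
    from bitsandbytes.nn import Linear8bitLt

    layer = Linear8bitLt(1024, 1024, bias=False, has_fp16_weights=False, threshold=6.0)
    layer.weight.data = torch.randn(1024, 1024, dtype=torch.float16)

    print(layer.weight.data.dtype)  # torch.float16 while still on the CPU
    layer = layer.to("cuda")
    print(layer.weight.data.dtype)  # torch.int8 once moved to the GPU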
@@ -139,7 +137,7 @@ class FlashLlama(FlashCausalLM):
                 del value

         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)


 class FlashLlamaSharded(FlashLlama):
@@ -154,9 +152,6 @@ class FlashLlamaSharded(FlashLlama):
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")

-        if quantize:
-            raise NotImplementedError("FlashLlama does not support quantization")
-
         tokenizer = LlamaTokenizer.from_pretrained(
             model_id,
             revision=revision,
@@ -185,7 +180,7 @@ class FlashLlamaSharded(FlashLlama):
             rank=self.rank,
             world_size=self.world_size,
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -300,4 +295,4 @@ class FlashLlamaSharded(FlashLlama):
                     else:
                         module._buffers[param_name] = tensor
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)
@@ -41,9 +41,6 @@ class FlashNeoXSharded(FlashNeoX):
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")

-        if quantize:
-            raise NotImplementedError("FlashNeoX does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -63,13 +60,13 @@ class FlashNeoXSharded(FlashNeoX):
         self.load_weights(
             model,
             filenames,
+            quantize=quantize,
             device=device,
             dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
         )
-        model.post_load_weights()
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -80,6 +77,7 @@ class FlashNeoXSharded(FlashNeoX):
     def load_weights(
         model,
         filenames: List[str],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -87,7 +85,9 @@ class FlashNeoXSharded(FlashNeoX):
     ):
         parameters = dict(model.named_parameters())
         for file in filenames:
-            with safe_open(file, framework="pt", device=str(device)) as f:
+            with safe_open(
+                file, framework="pt", device=str(device) if not quantize else "cpu"
+            ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
                     module = model.get_submodule(module_name)
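When quantizing, `safe_open(..., device="cpu")` materializes each tensor on the host first, matching the `torch.load(..., map_location="cpu")` path used for the non-safetensors loaders; the tensors only reach the GPU after `post_load_weights` has run. For reference, the basic safetensors reading pattern looks like this (the file name is a placeholder):

    # Basic safetensors read, mirroring the loop above; file name is a placeholder.
    from safetensors import safe_open

    with safe_open("model.safetensors", framework="pt", device="cpu") as f:
        for name in f.keys():
            tensor = f.get_tensor(name)  # torch.Tensor materialized on the CPU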
@@ -146,3 +146,4 @@ class FlashNeoXSharded(FlashNeoX):
                         module._parameters[param_name] = tensor
                     else:
                         module._buffers[param_name] = tensor
+        model.post_load_weights(quantize)
@@ -34,9 +34,6 @@ class FlashSantacoder(FlashCausalLM):
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")

-        if quantize:
-            raise NotImplementedError("FlashSantacoder does not support quantization")
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -58,9 +55,14 @@ class FlashSantacoder(FlashCausalLM):
             model = FlashSantacoderForCausalLM(config)

         self.load_weights(
-            model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
+            model,
+            filenames,
+            quantize,
+            device,
+            dtype,
+            config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)

         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer, device=device, decode_buffer=1
@@ -70,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
     def load_weights(
         model: FlashSantacoderForCausalLM,
         filenames: List[Path],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         transpose: bool,
@@ -77,7 +80,7 @@ class FlashSantacoder(FlashCausalLM):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device).to(dtype)
+                value = value.to(device if not quantize else "cpu").to(dtype)

                 layer_name = ".".join(key.split(".")[:4])

@@ -152,7 +155,7 @@ class FlashSantacoder(FlashCausalLM):
                     del value

         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)

     def decode(self, generated_ids: List[int]) -> str:
         # Do not skip special tokens as they are used for custom parsing rules of the generated text
@@ -173,11 +176,6 @@ class FlashSantacoderSharded(FlashSantacoder):
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

-        if quantize:
-            raise NotImplementedError(
-                "FlashSantacoderSharded does not support quantization"
-            )
-
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
@@ -197,13 +195,14 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.load_weights(
             model,
             filenames,
+            quantize=quantize,
             device=device,
             dtype=dtype,
             rank=self.rank,
             world_size=self.world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval()
+        self.model = model.eval().to(device)
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
             tokenizer=tokenizer,
@@ -214,6 +213,7 @@ class FlashSantacoderSharded(FlashSantacoder):
     def load_weights(
         model,
         filenames: List[str],
+        quantize: bool,
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
|
@ -221,7 +221,9 @@ class FlashSantacoderSharded(FlashSantacoder):
|
||||||
transpose: bool,
|
transpose: bool,
|
||||||
):
|
):
|
||||||
for file in filenames:
|
for file in filenames:
|
||||||
with safe_open(file, framework="pt", device=str(device)) as f:
|
with safe_open(
|
||||||
|
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||||
|
) as f:
|
||||||
for key in f.keys():
|
for key in f.keys():
|
||||||
slice_ = f.get_slice(key)
|
slice_ = f.get_slice(key)
|
||||||
|
|
||||||
|
@@ -363,4 +365,4 @@ class FlashSantacoderSharded(FlashSantacoder):
                     else:
                         module._buffers[param_name] = tensor
         torch.cuda.empty_cache()
-        model.post_load_weights()
+        model.post_load_weights(quantize)