parent
2475aede61
commit
e14ae3b5e9
|
@ -26,7 +26,9 @@ try:
|
|||
|
||||
FLASH_ATTENTION = torch.cuda.is_available()
|
||||
except ImportError:
|
||||
logger.opt(exception=True).warning("Could not import Flash Attention enabled models")
|
||||
logger.opt(exception=True).warning(
|
||||
"Could not import Flash Attention enabled models"
|
||||
)
|
||||
FLASH_ATTENTION = False
|
||||
|
||||
__all__ = [
|
||||
|
@ -88,10 +90,10 @@ def get_model(
|
|||
raise NotImplementedError(
|
||||
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
|
||||
)
|
||||
return FlashSantacoderSharded(model_id, revision=revision)
|
||||
return FlashSantacoderSharded(model_id, revision, quantize=quantize)
|
||||
else:
|
||||
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
|
||||
return santacoder_cls(model_id, revision, quantize)
|
||||
return santacoder_cls(model_id, revision, quantize=quantize)
|
||||
|
||||
config = AutoConfig.from_pretrained(model_id, revision=revision)
|
||||
model_type = config.model_type
|
||||
|
|
|
@ -33,6 +33,12 @@ import dropout_layer_norm
|
|||
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
except ImportError as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class LlamaRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
|
@ -94,11 +100,41 @@ class FastLinear(nn.Linear):
|
|||
dtype=None,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
self.quantized = False
|
||||
self.bnb_linear = None
|
||||
|
||||
def transpose_weight(self):
|
||||
def prepare_weights(self, quantize: bool = False):
|
||||
if quantize:
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
self.quantized = True
|
||||
self.bnb_linear = Linear8bitLt(
|
||||
self.in_features,
|
||||
self.out_features,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
bias=False,
|
||||
)
|
||||
# Copy data to bnb_linear
|
||||
self.bnb_linear.weight.data = self.weight.data
|
||||
if self.bias is not None:
|
||||
self.bnb_linear.bias = nn.Parameter(self.bias)
|
||||
|
||||
# Delete reference to data
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
else:
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantized:
|
||||
return self.bnb_linear(input)
|
||||
else:
|
||||
if self.bias is not None:
|
||||
return torch.addmm(self.bias, input, self.weight)
|
||||
return torch.matmul(input, self.weight)
|
||||
|
@ -502,15 +538,15 @@ class FlashLlamaModel(torch.nn.Module):
|
|||
self.head_size = self.layers[0].self_attn.head_size
|
||||
self.num_heads = self.layers[0].self_attn.num_heads
|
||||
|
||||
def post_load_weights(self):
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
if isinstance(self.embed_tokens, TensorParallelEmbedding):
|
||||
self.embed_tokens.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashLlamaLayer
|
||||
layer.self_attn.query_key_value.transpose_weight()
|
||||
layer.self_attn.o_proj.transpose_weight()
|
||||
layer.mlp.gate_up_proj.transpose_weight()
|
||||
layer.mlp.down_proj.transpose_weight()
|
||||
layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
|
||||
layer.self_attn.o_proj.prepare_weights(load_in_8bit)
|
||||
layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
|
||||
layer.mlp.down_proj.prepare_weights(load_in_8bit)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -592,9 +628,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
|
|||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self):
|
||||
self.model.post_load_weights()
|
||||
self.lm_head.transpose_weight()
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
self.model.post_load_weights(load_in_8bit)
|
||||
self.lm_head.prepare_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -35,6 +35,12 @@ import dropout_layer_norm
|
|||
|
||||
from flash_attn.layers.rotary import RotaryEmbedding
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
except ImportError as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
|
@ -82,11 +88,41 @@ class FastLinear(nn.Linear):
|
|||
dtype=None,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
self.quantized = False
|
||||
self.bnb_linear = None
|
||||
|
||||
def transpose_weight(self):
|
||||
def prepare_weights(self, quantize: bool = False):
|
||||
if quantize:
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
self.quantized = True
|
||||
self.bnb_linear = Linear8bitLt(
|
||||
self.in_features,
|
||||
self.out_features,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
bias=False,
|
||||
)
|
||||
# Copy data to bnb_linear
|
||||
self.bnb_linear.weight.data = self.weight.data
|
||||
if self.bias is not None:
|
||||
self.bnb_linear.bias = nn.Parameter(self.bias)
|
||||
|
||||
# Delete reference to data
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
else:
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantized:
|
||||
return self.bnb_linear(input)
|
||||
else:
|
||||
if self.bias is not None:
|
||||
return torch.addmm(self.bias, input, self.weight)
|
||||
return torch.matmul(input, self.weight)
|
||||
|
@ -552,23 +588,27 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
|
|||
self.head_size = self.layers[0].attention.head_size
|
||||
self.num_heads = self.layers[0].attention.num_heads
|
||||
|
||||
def post_load_weights(self):
|
||||
def post_load_weights(self, load_in_8bit=False):
|
||||
if isinstance(self.embed_in, TensorParallelEmbedding):
|
||||
self.embed_in.add_null_idx()
|
||||
for layer in self.layers:
|
||||
layer: FlashNeoXLayer
|
||||
layer.attention.shuffle_qkv_dims()
|
||||
layer.attention.query_key_value.transpose_weight()
|
||||
layer.attention.dense.transpose_weight()
|
||||
layer.mlp.dense_h_to_4h.transpose_weight()
|
||||
layer.mlp.dense_4h_to_h.transpose_weight()
|
||||
layer.attention.query_key_value.prepare_weights(load_in_8bit)
|
||||
layer.attention.dense.prepare_weights(load_in_8bit)
|
||||
layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
|
||||
layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashGPTNeoXModel, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
model.post_load_weights()
|
||||
|
||||
model.post_load_weights(load_in_8bit)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
|
@ -653,16 +693,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
|
|||
config.hidden_size, config.vocab_size, bias=False
|
||||
)
|
||||
|
||||
def post_load_weights(self):
|
||||
self.gpt_neox.post_load_weights()
|
||||
self.embed_out.transpose_weight()
|
||||
def post_load_weights(self, load_in_8bit=False):
|
||||
self.gpt_neox.post_load_weights(load_in_8bit)
|
||||
self.embed_out.prepare_weights()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
|
||||
# to do it for us
|
||||
load_in_8bit = kwargs.pop("load_in_8bit", False)
|
||||
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
|
||||
pretrained_model_name_or_path, *model_args, **kwargs
|
||||
pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
|
||||
)
|
||||
model.post_load_weights()
|
||||
model.post_load_weights(load_in_8bit)
|
||||
return model
|
||||
|
||||
def forward(
|
||||
|
|
|
@ -10,6 +10,12 @@ from transformers.activations import ACT2FN
|
|||
import flash_attn_cuda
|
||||
import dropout_layer_norm
|
||||
|
||||
HAS_BITS_AND_BYTES = True
|
||||
try:
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
except ImportError as e:
|
||||
HAS_BITS_AND_BYTES = False
|
||||
|
||||
|
||||
class FastLayerNorm(nn.LayerNorm):
|
||||
def forward(self, hidden_states, residual=None):
|
||||
|
@ -57,11 +63,41 @@ class FastLinear(nn.Linear):
|
|||
dtype=None,
|
||||
) -> None:
|
||||
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
|
||||
self.quantized = False
|
||||
self.bnb_linear = None
|
||||
|
||||
def transpose_weight(self):
|
||||
def prepare_weights(self, quantize: bool = False):
|
||||
if quantize:
|
||||
if not HAS_BITS_AND_BYTES:
|
||||
raise ImportError(
|
||||
"bitsandbytes is not available on your machine either because it is not installed "
|
||||
"or you don't have a GPU.\n"
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
self.quantized = True
|
||||
self.bnb_linear = Linear8bitLt(
|
||||
self.in_features,
|
||||
self.out_features,
|
||||
has_fp16_weights=False,
|
||||
threshold=6.0,
|
||||
bias=False,
|
||||
)
|
||||
# Copy data to bnb_linear
|
||||
self.bnb_linear.weight.data = self.weight.data
|
||||
if self.bias is not None:
|
||||
self.bnb_linear.bias = nn.Parameter(self.bias)
|
||||
|
||||
# Delete reference to data
|
||||
self.weight = None
|
||||
self.bias = None
|
||||
else:
|
||||
self.weight = nn.Parameter(self.weight.T)
|
||||
|
||||
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
||||
if self.quantized:
|
||||
return self.bnb_linear(input)
|
||||
else:
|
||||
if self.bias is not None:
|
||||
return torch.addmm(self.bias, input, self.weight)
|
||||
return torch.matmul(input, self.weight)
|
||||
|
@ -431,16 +467,16 @@ class FlashSantacoderModel(nn.Module):
|
|||
self.head_size = self.h[0].attn.head_size
|
||||
self.num_heads = self.h[0].attn.num_heads
|
||||
|
||||
def post_load_weights(self):
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
if self.tp_embeddings:
|
||||
self.wte.add_null_idx()
|
||||
self.wpe.add_null_idx()
|
||||
for layer in self.h:
|
||||
layer: Block
|
||||
layer.attn.c_attn.transpose_weight()
|
||||
layer.attn.c_proj.transpose_weight()
|
||||
layer.mlp.c_fc.transpose_weight()
|
||||
layer.mlp.c_proj.transpose_weight()
|
||||
layer.attn.c_attn.prepare_weights(load_in_8bit)
|
||||
layer.attn.c_proj.prepare_weights(load_in_8bit)
|
||||
layer.mlp.c_fc.prepare_weights(load_in_8bit)
|
||||
layer.mlp.c_proj.prepare_weights(load_in_8bit)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
@ -508,9 +544,9 @@ class FlashSantacoderForCausalLM(nn.Module):
|
|||
else:
|
||||
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def post_load_weights(self):
|
||||
self.transformer.post_load_weights()
|
||||
self.lm_head.transpose_weight()
|
||||
def post_load_weights(self, load_in_8bit: bool = False):
|
||||
self.transformer.post_load_weights(load_in_8bit)
|
||||
self.lm_head.prepare_weights()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
|
|
@ -221,9 +221,6 @@ class FlashCausalLM(Model):
|
|||
else:
|
||||
raise NotImplementedError("FlashCausalLM is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashCausalLM does not support quantization")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
|
@ -232,9 +229,10 @@ class FlashCausalLM(Model):
|
|||
model_id,
|
||||
revision=revision,
|
||||
torch_dtype=dtype,
|
||||
load_in_8bit=quantize,
|
||||
)
|
||||
.eval()
|
||||
.cuda()
|
||||
.to(device)
|
||||
)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
|
|
|
@ -35,9 +35,6 @@ class FlashLlama(FlashCausalLM):
|
|||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashLlama does not support quantization")
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
|
@ -61,8 +58,8 @@ class FlashLlama(FlashCausalLM):
|
|||
with init_empty_weights():
|
||||
model = FlashLlamaForCausalLM(config)
|
||||
|
||||
self.load_weights(model, filenames, device, dtype)
|
||||
self.model = model.eval()
|
||||
self.load_weights(model, filenames, quantize, device, dtype)
|
||||
self.model = model.eval().to(device)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
|
@ -73,13 +70,14 @@ class FlashLlama(FlashCausalLM):
|
|||
def load_weights(
|
||||
model,
|
||||
filenames: List[Path],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
):
|
||||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device).to(dtype)
|
||||
value = value.to(device if not quantize else "cpu").to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
|
@ -139,7 +137,7 @@ class FlashLlama(FlashCausalLM):
|
|||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
|
||||
class FlashLlamaSharded(FlashLlama):
|
||||
|
@ -154,9 +152,6 @@ class FlashLlamaSharded(FlashLlama):
|
|||
else:
|
||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashLlama does not support quantization")
|
||||
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
model_id,
|
||||
revision=revision,
|
||||
|
@ -185,7 +180,7 @@ class FlashLlamaSharded(FlashLlama):
|
|||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
self.model = model.eval()
|
||||
self.model = model.eval().to(device)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
|
@ -300,4 +295,4 @@ class FlashLlamaSharded(FlashLlama):
|
|||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
model.post_load_weights(quantize)
|
||||
|
|
|
@ -41,9 +41,6 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
else:
|
||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashNeoX does not support quantization")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
|
@ -63,13 +60,13 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
)
|
||||
model.post_load_weights()
|
||||
self.model = model.eval()
|
||||
self.model = model.eval().to(device)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
|
@ -80,6 +77,7 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
|
@ -87,7 +85,9 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
):
|
||||
parameters = dict(model.named_parameters())
|
||||
for file in filenames:
|
||||
with safe_open(file, framework="pt", device=str(device)) as f:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for name in f.keys():
|
||||
module_name, param_name = name.rsplit(".", 1)
|
||||
module = model.get_submodule(module_name)
|
||||
|
@ -146,3 +146,4 @@ class FlashNeoXSharded(FlashNeoX):
|
|||
module._parameters[param_name] = tensor
|
||||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
model.post_load_weights(quantize)
|
||||
|
|
|
@ -34,9 +34,6 @@ class FlashSantacoder(FlashCausalLM):
|
|||
else:
|
||||
raise NotImplementedError("FlashSantacoder is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError("FlashSantacoder does not support quantization")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
|
@ -58,9 +55,14 @@ class FlashSantacoder(FlashCausalLM):
|
|||
model = FlashSantacoderForCausalLM(config)
|
||||
|
||||
self.load_weights(
|
||||
model, filenames, device, dtype, config.architectures[0].startswith("GPT2")
|
||||
model,
|
||||
filenames,
|
||||
quantize,
|
||||
device,
|
||||
dtype,
|
||||
config.architectures[0].startswith("GPT2"),
|
||||
)
|
||||
self.model = model.eval()
|
||||
self.model = model.eval().to(device)
|
||||
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer, device=device, decode_buffer=1
|
||||
|
@ -70,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||
def load_weights(
|
||||
model: FlashSantacoderForCausalLM,
|
||||
filenames: List[Path],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
transpose: bool,
|
||||
|
@ -77,7 +80,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||
for filename in filenames:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
for key, value in state_dict.items():
|
||||
value = value.to(device).to(dtype)
|
||||
value = value.to(device if not quantize else "cpu").to(dtype)
|
||||
|
||||
layer_name = ".".join(key.split(".")[:4])
|
||||
|
||||
|
@ -152,7 +155,7 @@ class FlashSantacoder(FlashCausalLM):
|
|||
del value
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
model.post_load_weights(quantize)
|
||||
|
||||
def decode(self, generated_ids: List[int]) -> str:
|
||||
# Do not skip special tokens as they are used for custom parsing rules of the generated text
|
||||
|
@ -173,11 +176,6 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
else:
|
||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||
|
||||
if quantize:
|
||||
raise NotImplementedError(
|
||||
"FlashSantacoderSharded does not support quantization"
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_id, revision=revision, padding_side="left", truncation_side="left"
|
||||
)
|
||||
|
@ -197,13 +195,14 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
self.load_weights(
|
||||
model,
|
||||
filenames,
|
||||
quantize=quantize,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
transpose=config.architectures[0].startswith("GPT2"),
|
||||
)
|
||||
self.model = model.eval()
|
||||
self.model = model.eval().to(device)
|
||||
torch.distributed.barrier(group=self.process_group)
|
||||
super(FlashCausalLM, self).__init__(
|
||||
tokenizer=tokenizer,
|
||||
|
@ -214,6 +213,7 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
def load_weights(
|
||||
model,
|
||||
filenames: List[str],
|
||||
quantize: bool,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
rank: int,
|
||||
|
@ -221,7 +221,9 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
transpose: bool,
|
||||
):
|
||||
for file in filenames:
|
||||
with safe_open(file, framework="pt", device=str(device)) as f:
|
||||
with safe_open(
|
||||
file, framework="pt", device=str(device) if not quantize else "cpu"
|
||||
) as f:
|
||||
for key in f.keys():
|
||||
slice_ = f.get_slice(key)
|
||||
|
||||
|
@ -363,4 +365,4 @@ class FlashSantacoderSharded(FlashSantacoder):
|
|||
else:
|
||||
module._buffers[param_name] = tensor
|
||||
torch.cuda.empty_cache()
|
||||
model.post_load_weights()
|
||||
model.post_load_weights(quantize)
|
||||
|
|
Loading…
Reference in New Issue