feat(server): support quantization for flash models (#200)

closes #197
This commit is contained in:
OlivierDehaene 2023-04-19 12:51:11 +02:00 committed by GitHub
parent 2475aede61
commit e14ae3b5e9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 196 additions and 83 deletions

View File

@ -26,7 +26,9 @@ try:
FLASH_ATTENTION = torch.cuda.is_available() FLASH_ATTENTION = torch.cuda.is_available()
except ImportError: except ImportError:
logger.opt(exception=True).warning("Could not import Flash Attention enabled models") logger.opt(exception=True).warning(
"Could not import Flash Attention enabled models"
)
FLASH_ATTENTION = False FLASH_ATTENTION = False
__all__ = [ __all__ = [
@ -88,10 +90,10 @@ def get_model(
raise NotImplementedError( raise NotImplementedError(
FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder") FLASH_ATT_ERROR_MESSAGE.format(f"Sharded Santacoder")
) )
return FlashSantacoderSharded(model_id, revision=revision) return FlashSantacoderSharded(model_id, revision, quantize=quantize)
else: else:
santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder santacoder_cls = FlashSantacoder if FLASH_ATTENTION else SantaCoder
return santacoder_cls(model_id, revision, quantize) return santacoder_cls(model_id, revision, quantize=quantize)
config = AutoConfig.from_pretrained(model_id, revision=revision) config = AutoConfig.from_pretrained(model_id, revision=revision)
model_type = config.model_type model_type = config.model_type

View File

@ -33,6 +33,12 @@ import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.layers.rotary import RotaryEmbedding
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
class LlamaRMSNorm(nn.Module): class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6): def __init__(self, hidden_size, eps=1e-6):
@ -94,14 +100,44 @@ class FastLinear(nn.Linear):
dtype=None, dtype=None,
) -> None: ) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def transpose_weight(self): def prepare_weights(self, quantize: bool = False):
self.weight = nn.Parameter(self.weight.T) if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
else:
self.weight = nn.Parameter(self.weight.T)
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.bias is not None: if self.quantized:
return torch.addmm(self.bias, input, self.weight) return self.bnb_linear(input)
return torch.matmul(input, self.weight) else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear): class TensorParallelColumnLinear(FastLinear):
@ -502,15 +538,15 @@ class FlashLlamaModel(torch.nn.Module):
self.head_size = self.layers[0].self_attn.head_size self.head_size = self.layers[0].self_attn.head_size
self.num_heads = self.layers[0].self_attn.num_heads self.num_heads = self.layers[0].self_attn.num_heads
def post_load_weights(self): def post_load_weights(self, load_in_8bit: bool = False):
if isinstance(self.embed_tokens, TensorParallelEmbedding): if isinstance(self.embed_tokens, TensorParallelEmbedding):
self.embed_tokens.add_null_idx() self.embed_tokens.add_null_idx()
for layer in self.layers: for layer in self.layers:
layer: FlashLlamaLayer layer: FlashLlamaLayer
layer.self_attn.query_key_value.transpose_weight() layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
layer.self_attn.o_proj.transpose_weight() layer.self_attn.o_proj.prepare_weights(load_in_8bit)
layer.mlp.gate_up_proj.transpose_weight() layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
layer.mlp.down_proj.transpose_weight() layer.mlp.down_proj.prepare_weights(load_in_8bit)
def forward( def forward(
self, self,
@ -592,9 +628,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
else: else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self): def post_load_weights(self, load_in_8bit: bool = False):
self.model.post_load_weights() self.model.post_load_weights(load_in_8bit)
self.lm_head.transpose_weight() self.lm_head.prepare_weights()
def forward( def forward(
self, self,

View File

@ -35,6 +35,12 @@ import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.layers.rotary import RotaryEmbedding
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
@ -82,14 +88,44 @@ class FastLinear(nn.Linear):
dtype=None, dtype=None,
) -> None: ) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def transpose_weight(self): def prepare_weights(self, quantize: bool = False):
self.weight = nn.Parameter(self.weight.T) if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
else:
self.weight = nn.Parameter(self.weight.T)
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.bias is not None: if self.quantized:
return torch.addmm(self.bias, input, self.weight) return self.bnb_linear(input)
return torch.matmul(input, self.weight) else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear): class TensorParallelColumnLinear(FastLinear):
@ -552,23 +588,27 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
self.head_size = self.layers[0].attention.head_size self.head_size = self.layers[0].attention.head_size
self.num_heads = self.layers[0].attention.num_heads self.num_heads = self.layers[0].attention.num_heads
def post_load_weights(self): def post_load_weights(self, load_in_8bit=False):
if isinstance(self.embed_in, TensorParallelEmbedding): if isinstance(self.embed_in, TensorParallelEmbedding):
self.embed_in.add_null_idx() self.embed_in.add_null_idx()
for layer in self.layers: for layer in self.layers:
layer: FlashNeoXLayer layer: FlashNeoXLayer
layer.attention.shuffle_qkv_dims() layer.attention.shuffle_qkv_dims()
layer.attention.query_key_value.transpose_weight() layer.attention.query_key_value.prepare_weights(load_in_8bit)
layer.attention.dense.transpose_weight() layer.attention.dense.prepare_weights(load_in_8bit)
layer.mlp.dense_h_to_4h.transpose_weight() layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
layer.mlp.dense_4h_to_h.transpose_weight() layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXModel, cls).from_pretrained( model = super(FlashGPTNeoXModel, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
) )
model.post_load_weights()
model.post_load_weights(load_in_8bit)
return model return model
def forward( def forward(
@ -653,16 +693,19 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
config.hidden_size, config.vocab_size, bias=False config.hidden_size, config.vocab_size, bias=False
) )
def post_load_weights(self): def post_load_weights(self, load_in_8bit=False):
self.gpt_neox.post_load_weights() self.gpt_neox.post_load_weights(load_in_8bit)
self.embed_out.transpose_weight() self.embed_out.prepare_weights()
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
# Pop here as we will replace the layer in our own logic and don't want from_pretrained
# to do it for us
load_in_8bit = kwargs.pop("load_in_8bit", False)
model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained( model = super(FlashGPTNeoXForCausalLM, cls).from_pretrained(
pretrained_model_name_or_path, *model_args, **kwargs pretrained_model_name_or_path, load_in_8bit=False, *model_args, **kwargs
) )
model.post_load_weights() model.post_load_weights(load_in_8bit)
return model return model
def forward( def forward(

View File

@ -10,6 +10,12 @@ from transformers.activations import ACT2FN
import flash_attn_cuda import flash_attn_cuda
import dropout_layer_norm import dropout_layer_norm
HAS_BITS_AND_BYTES = True
try:
from bitsandbytes.nn import Linear8bitLt
except ImportError as e:
HAS_BITS_AND_BYTES = False
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
@ -57,14 +63,44 @@ class FastLinear(nn.Linear):
dtype=None, dtype=None,
) -> None: ) -> None:
super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype) super(FastLinear, self).__init__(in_features, out_features, bias, device, dtype)
self.quantized = False
self.bnb_linear = None
def transpose_weight(self): def prepare_weights(self, quantize: bool = False):
self.weight = nn.Parameter(self.weight.T) if quantize:
if not HAS_BITS_AND_BYTES:
raise ImportError(
"bitsandbytes is not available on your machine either because it is not installed "
"or you don't have a GPU.\n"
"You can install it with `pip install bitsandbytes`."
)
self.quantized = True
self.bnb_linear = Linear8bitLt(
self.in_features,
self.out_features,
has_fp16_weights=False,
threshold=6.0,
bias=False,
)
# Copy data to bnb_linear
self.bnb_linear.weight.data = self.weight.data
if self.bias is not None:
self.bnb_linear.bias = nn.Parameter(self.bias)
# Delete reference to data
self.weight = None
self.bias = None
else:
self.weight = nn.Parameter(self.weight.T)
def forward(self, input: torch.Tensor) -> torch.Tensor: def forward(self, input: torch.Tensor) -> torch.Tensor:
if self.bias is not None: if self.quantized:
return torch.addmm(self.bias, input, self.weight) return self.bnb_linear(input)
return torch.matmul(input, self.weight) else:
if self.bias is not None:
return torch.addmm(self.bias, input, self.weight)
return torch.matmul(input, self.weight)
class TensorParallelColumnLinear(FastLinear): class TensorParallelColumnLinear(FastLinear):
@ -431,16 +467,16 @@ class FlashSantacoderModel(nn.Module):
self.head_size = self.h[0].attn.head_size self.head_size = self.h[0].attn.head_size
self.num_heads = self.h[0].attn.num_heads self.num_heads = self.h[0].attn.num_heads
def post_load_weights(self): def post_load_weights(self, load_in_8bit: bool = False):
if self.tp_embeddings: if self.tp_embeddings:
self.wte.add_null_idx() self.wte.add_null_idx()
self.wpe.add_null_idx() self.wpe.add_null_idx()
for layer in self.h: for layer in self.h:
layer: Block layer: Block
layer.attn.c_attn.transpose_weight() layer.attn.c_attn.prepare_weights(load_in_8bit)
layer.attn.c_proj.transpose_weight() layer.attn.c_proj.prepare_weights(load_in_8bit)
layer.mlp.c_fc.transpose_weight() layer.mlp.c_fc.prepare_weights(load_in_8bit)
layer.mlp.c_proj.transpose_weight() layer.mlp.c_proj.prepare_weights(load_in_8bit)
def forward( def forward(
self, self,
@ -508,9 +544,9 @@ class FlashSantacoderForCausalLM(nn.Module):
else: else:
self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False) self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)
def post_load_weights(self): def post_load_weights(self, load_in_8bit: bool = False):
self.transformer.post_load_weights() self.transformer.post_load_weights(load_in_8bit)
self.lm_head.transpose_weight() self.lm_head.prepare_weights()
def forward( def forward(
self, self,

View File

@ -221,9 +221,6 @@ class FlashCausalLM(Model):
else: else:
raise NotImplementedError("FlashCausalLM is only available on GPU") raise NotImplementedError("FlashCausalLM is only available on GPU")
if quantize:
raise NotImplementedError("FlashCausalLM does not support quantization")
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
@ -232,9 +229,10 @@ class FlashCausalLM(Model):
model_id, model_id,
revision=revision, revision=revision,
torch_dtype=dtype, torch_dtype=dtype,
load_in_8bit=quantize,
) )
.eval() .eval()
.cuda() .to(device)
) )
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(

View File

@ -35,9 +35,6 @@ class FlashLlama(FlashCausalLM):
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")
if quantize:
raise NotImplementedError("FlashLlama does not support quantization")
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
@ -61,8 +58,8 @@ class FlashLlama(FlashCausalLM):
with init_empty_weights(): with init_empty_weights():
model = FlashLlamaForCausalLM(config) model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, device, dtype) self.load_weights(model, filenames, quantize, device, dtype)
self.model = model.eval() self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -73,13 +70,14 @@ class FlashLlama(FlashCausalLM):
def load_weights( def load_weights(
model, model,
filenames: List[Path], filenames: List[Path],
quantize: bool,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
): ):
for filename in filenames: for filename in filenames:
state_dict = torch.load(filename, map_location="cpu") state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items(): for key, value in state_dict.items():
value = value.to(device).to(dtype) value = value.to(device if not quantize else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4]) layer_name = ".".join(key.split(".")[:4])
@ -139,7 +137,7 @@ class FlashLlama(FlashCausalLM):
del value del value
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights() model.post_load_weights(quantize)
class FlashLlamaSharded(FlashLlama): class FlashLlamaSharded(FlashLlama):
@ -154,9 +152,6 @@ class FlashLlamaSharded(FlashLlama):
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")
if quantize:
raise NotImplementedError("FlashLlama does not support quantization")
tokenizer = LlamaTokenizer.from_pretrained( tokenizer = LlamaTokenizer.from_pretrained(
model_id, model_id,
revision=revision, revision=revision,
@ -185,7 +180,7 @@ class FlashLlamaSharded(FlashLlama):
rank=self.rank, rank=self.rank,
world_size=self.world_size, world_size=self.world_size,
) )
self.model = model.eval() self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -300,4 +295,4 @@ class FlashLlamaSharded(FlashLlama):
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights() model.post_load_weights(quantize)

View File

@ -41,9 +41,6 @@ class FlashNeoXSharded(FlashNeoX):
else: else:
raise NotImplementedError("FlashNeoX is only available on GPU") raise NotImplementedError("FlashNeoX is only available on GPU")
if quantize:
raise NotImplementedError("FlashNeoX does not support quantization")
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
@ -63,13 +60,13 @@ class FlashNeoXSharded(FlashNeoX):
self.load_weights( self.load_weights(
model, model,
filenames, filenames,
quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=self.rank,
world_size=self.world_size, world_size=self.world_size,
) )
model.post_load_weights() self.model = model.eval().to(device)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -80,6 +77,7 @@ class FlashNeoXSharded(FlashNeoX):
def load_weights( def load_weights(
model, model,
filenames: List[str], filenames: List[str],
quantize: bool,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
rank: int, rank: int,
@ -87,7 +85,9 @@ class FlashNeoXSharded(FlashNeoX):
): ):
parameters = dict(model.named_parameters()) parameters = dict(model.named_parameters())
for file in filenames: for file in filenames:
with safe_open(file, framework="pt", device=str(device)) as f: with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for name in f.keys(): for name in f.keys():
module_name, param_name = name.rsplit(".", 1) module_name, param_name = name.rsplit(".", 1)
module = model.get_submodule(module_name) module = model.get_submodule(module_name)
@ -146,3 +146,4 @@ class FlashNeoXSharded(FlashNeoX):
module._parameters[param_name] = tensor module._parameters[param_name] = tensor
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
model.post_load_weights(quantize)

View File

@ -34,9 +34,6 @@ class FlashSantacoder(FlashCausalLM):
else: else:
raise NotImplementedError("FlashSantacoder is only available on GPU") raise NotImplementedError("FlashSantacoder is only available on GPU")
if quantize:
raise NotImplementedError("FlashSantacoder does not support quantization")
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
@ -58,9 +55,14 @@ class FlashSantacoder(FlashCausalLM):
model = FlashSantacoderForCausalLM(config) model = FlashSantacoderForCausalLM(config)
self.load_weights( self.load_weights(
model, filenames, device, dtype, config.architectures[0].startswith("GPT2") model,
filenames,
quantize,
device,
dtype,
config.architectures[0].startswith("GPT2"),
) )
self.model = model.eval() self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, device=device, decode_buffer=1 tokenizer=tokenizer, device=device, decode_buffer=1
@ -70,6 +72,7 @@ class FlashSantacoder(FlashCausalLM):
def load_weights( def load_weights(
model: FlashSantacoderForCausalLM, model: FlashSantacoderForCausalLM,
filenames: List[Path], filenames: List[Path],
quantize: bool,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
transpose: bool, transpose: bool,
@ -77,7 +80,7 @@ class FlashSantacoder(FlashCausalLM):
for filename in filenames: for filename in filenames:
state_dict = torch.load(filename, map_location="cpu") state_dict = torch.load(filename, map_location="cpu")
for key, value in state_dict.items(): for key, value in state_dict.items():
value = value.to(device).to(dtype) value = value.to(device if not quantize else "cpu").to(dtype)
layer_name = ".".join(key.split(".")[:4]) layer_name = ".".join(key.split(".")[:4])
@ -152,7 +155,7 @@ class FlashSantacoder(FlashCausalLM):
del value del value
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights() model.post_load_weights(quantize)
def decode(self, generated_ids: List[int]) -> str: def decode(self, generated_ids: List[int]) -> str:
# Do not skip special tokens as they are used for custom parsing rules of the generated text # Do not skip special tokens as they are used for custom parsing rules of the generated text
@ -173,11 +176,6 @@ class FlashSantacoderSharded(FlashSantacoder):
else: else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
if quantize:
raise NotImplementedError(
"FlashSantacoderSharded does not support quantization"
)
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left" model_id, revision=revision, padding_side="left", truncation_side="left"
) )
@ -197,13 +195,14 @@ class FlashSantacoderSharded(FlashSantacoder):
self.load_weights( self.load_weights(
model, model,
filenames, filenames,
quantize=quantize,
device=device, device=device,
dtype=dtype, dtype=dtype,
rank=self.rank, rank=self.rank,
world_size=self.world_size, world_size=self.world_size,
transpose=config.architectures[0].startswith("GPT2"), transpose=config.architectures[0].startswith("GPT2"),
) )
self.model = model.eval() self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group) torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__( super(FlashCausalLM, self).__init__(
tokenizer=tokenizer, tokenizer=tokenizer,
@ -214,6 +213,7 @@ class FlashSantacoderSharded(FlashSantacoder):
def load_weights( def load_weights(
model, model,
filenames: List[str], filenames: List[str],
quantize: bool,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
rank: int, rank: int,
@ -221,7 +221,9 @@ class FlashSantacoderSharded(FlashSantacoder):
transpose: bool, transpose: bool,
): ):
for file in filenames: for file in filenames:
with safe_open(file, framework="pt", device=str(device)) as f: with safe_open(
file, framework="pt", device=str(device) if not quantize else "cpu"
) as f:
for key in f.keys(): for key in f.keys():
slice_ = f.get_slice(key) slice_ = f.get_slice(key)
@ -363,4 +365,4 @@ class FlashSantacoderSharded(FlashSantacoder):
else: else:
module._buffers[param_name] = tensor module._buffers[param_name] = tensor
torch.cuda.empty_cache() torch.cuda.empty_cache()
model.post_load_weights() model.post_load_weights(quantize)