fix(server): fix fp8 weight loading (#2268)

* fix(server): fix fp8 weight loading

* fixed scales loading

* update snap

* revert default dtype
Authored by OlivierDehaene on 2024-07-22 15:51:32 +00:00, committed by GitHub
parent 6aebf44f47
commit 4844ff790a
5 changed files with 82 additions and 62 deletions


@@ -11,12 +11,12 @@
       },
       {
         "id": 2323,
-        "logprob": -9.421875,
+        "logprob": -9.5625,
         "text": "Test"
       },
       {
         "id": 1715,
-        "logprob": -10.546875,
+        "logprob": -10.375,
         "text": " request"
       }
     ],
@@ -24,66 +24,66 @@
     "tokens": [
       {
         "id": 25,
-        "logprob": -0.8535156,
+        "logprob": -0.8984375,
         "special": false,
         "text": ":"
       },
       {
         "id": 2209,
-        "logprob": -2.4804688,
+        "logprob": -2.78125,
         "special": false,
         "text": " Is"
       },
       {
         "id": 279,
-        "logprob": -0.7167969,
+        "logprob": -0.6328125,
         "special": false,
         "text": " the"
       },
       {
         "id": 734,
-        "logprob": -2.625,
+        "logprob": -2.703125,
         "special": false,
         "text": " function"
       },
       {
         "id": 330,
-        "logprob": -0.35131836,
+        "logprob": -0.34179688,
         "special": false,
         "text": " \""
       },
       {
         "id": 4110,
-        "logprob": -2.4101562,
+        "logprob": -2.359375,
         "special": false,
         "text": "Create"
       },
       {
-        "id": 264,
-        "logprob": -0.23181152,
+        "id": 7575,
+        "logprob": -2.1875,
         "special": false,
-        "text": " a"
-      },
-      {
-        "id": 502,
-        "logprob": -0.25512695,
-        "special": false,
-        "text": " new"
-      },
-      {
-        "id": 1052,
-        "logprob": -1.2792969,
-        "special": false,
-        "text": " file"
+        "text": "Process"
       },
       {
         "id": 1,
-        "logprob": -1.2529297,
+        "logprob": -0.07910156,
         "special": false,
         "text": "\""
+      },
+      {
+        "id": 304,
+        "logprob": -0.83203125,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 12468,
+        "logprob": -1.8203125,
+        "special": false,
+        "text": " Win"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request: Is the function \"Create a new file\""
+  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
 }


@@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
 
         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             # FP8 branch
             scale = weights.get_packed_sharded(
                 f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            )
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -115,8 +117,12 @@ class HybridFP8UnquantLoader(WeightsLoader):
         return UnquantizedWeight(w)
 
     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        w = torch.cat(w, dim=dim)
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)
 
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
@@ -124,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
                 for p in prefixes
             ]
-            scale = torch.cat(scale, dim=0)
+            scale = torch.cat(scale, dim=0).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -140,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,

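Note on the loader changes above: every weight scale is flattened with reshape(-1) so per-tensor and per-channel checkpoints present the same 1-D shape to Fp8Weight, and in get_multi_weights_col the fp8 shards stay on the CPU until after torch.cat because, per the FIXME, torch.cat is not yet supported for fp8 tensors on the device. A minimal standalone sketch of the same pattern in plain PyTorch (the shapes, device, and scale value are illustrative assumptions, not taken from the PR):

import torch

# Requires a torch build with float8_e4m3fn support; device and shapes are made up.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Keep fp8 shards on CPU, concatenate there, then do a single device transfer.
shards = [torch.randn(64, 32).to(torch.float8_e4m3fn) for _ in range(2)]
w = torch.cat(shards, dim=0).to(device)

# A per-tensor scale may be stored as a 0-D tensor in the checkpoint; reshape(-1)
# normalizes it to shape (1,) so downstream code treats per-tensor and
# per-channel scales uniformly.
scale = torch.tensor(0.02).reshape(-1)

print(w.shape, w.dtype, scale.shape)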

@@ -504,7 +504,7 @@ class GPTQMarlinFP8Linear(nn.Module):
     def __init__(
         self,
         qweight: torch.Tensor,
-        scale: torch.Tensor,
+        scales: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -514,8 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module):
         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")
 
-        scale = scale.to(torch.float16)
-        qweight, scales = repack_fp8_for_marlin(qweight, scale)
+        scales = scales.unsqueeze(0)
+        if scales.shape[1] == 1:
+            out_features, in_features = qweight.shape
+            scales = scales.repeat(1, out_features)
+        qweight, scales = repack_fp8_for_marlin(qweight, scales)
 
         in_features = qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = scales.shape[1]
@@ -530,13 +533,13 @@ class GPTQMarlinFP8Linear(nn.Module):
         )
 
     @classmethod
-    def from_unquant(cls, weight, bias, _dtype):
-        qweight, scale = fp8_quantize(weight)
-        return cls(qweight=qweight, scale=scale, bias=bias)
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scales = fp8_quantize(weight)
+        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)
 
     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
-        return cls(qweight=weight, scale=scale, bias=bias)
+    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)
 
     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
@@ -591,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
     return packed
 
 
-def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
+def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     """
     Repack FP8 tensor for GPTQ-Marlin.
     """
@@ -608,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
         qweight, perm, in_features, out_features, 8
     )
 
-    scales = scale.reshape(1, 1).repeat(1, out_features)
     scales = permute_scales(scales)
 
     return repacked, scales
@@ -621,7 +623,7 @@ class MarlinWeight(Weight):
     Attributes:
         B (torch.Tensor): int4-quantized weights packed into int32.
-        s (torch.Tensor): float16 scales.
+        s (torch.Tensor): bfloat16/float16 scales.
     """
 
     B: torch.Tensor
@@ -629,7 +631,7 @@ class MarlinWeight(Weight):
     def __post_init__(self):
         assert self.B.dtype == torch.int32
-        assert self.s.dtype == torch.float16
+        assert self.s.dtype in [torch.float16, torch.bfloat16]
 
     def get_linear(self, bias: torch.Tensor):
         return MarlinLinear(weight=self, bias=bias)

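Note on the Marlin changes above: GPTQMarlinFP8Linear now takes the scales tensor as-is and broadcasts a per-tensor scale to one scale per output column before repack_fp8_for_marlin, instead of repack_fp8_for_marlin hard-coding scale.reshape(1, 1).repeat(1, out_features). A standalone sketch of that shape handling (the sizes and scale value below are illustrative assumptions):

import torch

out_features, in_features = 4096, 11008  # illustrative sizes
qweight = torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn)

# Per-tensor quantization yields a single scale of shape (1,); the Marlin kernel
# expects a row of per-column scales of shape (1, out_features). A per-channel
# checkpoint already gives (out_features,), which unsqueeze(0) turns into
# (1, out_features), so the repeat is skipped.
scales = torch.tensor([0.015], dtype=torch.bfloat16)

scales = scales.unsqueeze(0)                 # (1,) -> (1, 1)
if scales.shape[1] == 1:
    out_features, in_features = qweight.shape
    scales = scales.repeat(1, out_features)  # -> (1, out_features)

print(scales.shape)  # torch.Size([1, 4096])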

@@ -306,14 +306,32 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION
 
+    config_dict, _ = PretrainedConfig.get_config_dict(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
+    model_type = config_dict.get("model_type", None)
+
+    quantization_config = config_dict.get("quantization_config", None)
+    if quantization_config is not None and quantize is None:
+        method = quantization_config.get("quant_method", None)
+        if method in {"gptq", "awq", "exl2"}:
+            log_master(logger.info, f"Auto selecting quantization method {method}")
+            quantize = method
+        elif method == "fbgemm_fp8":
+            log_master(logger.info, "Auto selecting quantization method fp8")
+            quantize = "fp8"
+        else:
+            log_master(logger.warning, f"Unknown quantization method {method}")
+
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
             # These quantizers only work with float16 params.
             dtype = torch.float16
         elif quantize == "fp8":
-            from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
-            if FBGEMM_MM_AVAILABLE:
+            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE
+            if FBGEMM_DYN_AVAILABLE:
                 # fbgemm kernels are fp8xfp8->bf16
                 dtype = torch.bfloat16
             else:
@@ -332,11 +350,6 @@ def get_model(
     else:
         set_speculate(0)
 
-    config_dict, _ = PretrainedConfig.get_config_dict(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config_dict.get("model_type", None)
-
     speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
@@ -451,14 +464,6 @@ def get_model(
         raise RuntimeError(
             f"Could not determine model type for {model_id} revision {revision}"
         )
-    quantization_config = config_dict.get("quantization_config", None)
-    if quantization_config is not None and quantize is None:
-        method = quantization_config.get("quant_method", None)
-        if method in {"gptq", "awq", "exl2"}:
-            log_master(logger.info, f"Auto selecting quantization method {method}")
-            quantize = method
-        else:
-            log_master(logger.warning, f"Unknown quantization method {method}")
 
     if quantize == "exl2" and sharded:
         raise RuntimeError(

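Note on the get_model changes above: reading the config and auto-selecting the quantization method now happens before the dtype defaulting, so a checkpoint whose quantization_config declares quant_method == "fbgemm_fp8" sets quantize = "fp8" early enough for the fp8 dtype branch to apply. A minimal sketch of that control flow (the config_dict literal is a stand-in for what PretrainedConfig.get_config_dict returns; no CLI handling is shown):

import torch

# Stand-in for the dict returned by PretrainedConfig.get_config_dict(model_id, ...).
config_dict = {
    "model_type": "llama",
    "quantization_config": {"quant_method": "fbgemm_fp8"},
}
quantize = None  # i.e. not forced by the caller
dtype = None

quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
    method = quantization_config.get("quant_method", None)
    if method in {"gptq", "awq", "exl2"}:
        quantize = method
    elif method == "fbgemm_fp8":
        # fbgemm_fp8 checkpoints are served through the generic fp8 path
        quantize = "fp8"

if dtype is None and quantize == "fp8":
    # mirrors the FBGEMM_DYN_AVAILABLE branch above: the fbgemm kernels compute
    # fp8 x fp8 -> bf16, so bfloat16 is the default when they are usable
    dtype = torch.bfloat16

print(quantize, dtype)  # fp8 torch.bfloat16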

@@ -230,7 +230,9 @@ class Weights:
         tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
+    def get_partial_sharded(
+        self, tensor_name: str, dim: int, to_device=True, to_dtype=True
+    ):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -256,10 +258,11 @@ class Weights:
             and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor
 
-    def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
+    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -268,7 +271,9 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
+        return self.get_partial_sharded(
+            tensor_name, dim, to_device=to_device, to_dtype=to_dtype
+        )
 
     def get_packed_sharded(
         self,
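Note on the Weights changes above: get_partial_sharded and get_sharded gain a to_device flag next to the existing to_dtype flag, so a caller such as get_multi_weights_col can keep fp8 shards on the CPU and move the concatenated tensor in a single step. A reduced sketch of the flag logic outside the Weights class (the function name and arguments are illustrative, not the TGI API):

import torch

def finalize_shard(tensor, device, dtype, to_device=True, to_dtype=True):
    # Weight scales and fp8 weights pass to_dtype=False so they keep their
    # storage dtype; fp8 weights additionally pass to_device=False so torch.cat
    # can run on the CPU before one transfer to the target device.
    if to_dtype and tensor.dtype not in (torch.int16, torch.int32, torch.int64):
        tensor = tensor.to(dtype=dtype)
    if to_device:
        tensor = tensor.to(device=device)
    return tensor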