fix(server): fix fp8 weight loading (#2268)
* fix(server): fix fp8 weight loading
* fixed scales loading
* update snap
* revert default dtype
parent 6aebf44f47
commit 4844ff790a
@@ -11,12 +11,12 @@
       },
       {
         "id": 2323,
-        "logprob": -9.421875,
+        "logprob": -9.5625,
         "text": "Test"
       },
       {
         "id": 1715,
-        "logprob": -10.546875,
+        "logprob": -10.375,
         "text": " request"
       }
     ],
@@ -24,66 +24,66 @@
     "tokens": [
       {
         "id": 25,
-        "logprob": -0.8535156,
+        "logprob": -0.8984375,
         "special": false,
         "text": ":"
       },
       {
         "id": 2209,
-        "logprob": -2.4804688,
+        "logprob": -2.78125,
         "special": false,
         "text": " Is"
       },
       {
         "id": 279,
-        "logprob": -0.7167969,
+        "logprob": -0.6328125,
         "special": false,
         "text": " the"
       },
       {
         "id": 734,
-        "logprob": -2.625,
+        "logprob": -2.703125,
         "special": false,
         "text": " function"
       },
       {
         "id": 330,
-        "logprob": -0.35131836,
+        "logprob": -0.34179688,
         "special": false,
         "text": " \""
       },
       {
         "id": 4110,
-        "logprob": -2.4101562,
+        "logprob": -2.359375,
         "special": false,
         "text": "Create"
       },
       {
-        "id": 264,
-        "logprob": -0.23181152,
+        "id": 7575,
+        "logprob": -2.1875,
         "special": false,
-        "text": " a"
-      },
-      {
-        "id": 502,
-        "logprob": -0.25512695,
-        "special": false,
-        "text": " new"
-      },
-      {
-        "id": 1052,
-        "logprob": -1.2792969,
-        "special": false,
-        "text": " file"
+        "text": "Process"
       },
       {
         "id": 1,
-        "logprob": -1.2529297,
+        "logprob": -0.07910156,
         "special": false,
         "text": "\""
+      },
+      {
+        "id": 304,
+        "logprob": -0.83203125,
+        "special": false,
+        "text": " in"
+      },
+      {
+        "id": 12468,
+        "logprob": -1.8203125,
+        "special": false,
+        "text": " Win"
       }
     ],
     "top_tokens": null
   },
-  "generated_text": "Test request: Is the function \"Create a new file\""
+  "generated_text": "Test request: Is the function \"CreateProcess\" in Win"
 }
@@ -76,7 +76,9 @@ class HybridFP8UnquantLoader(WeightsLoader):

         if w.dtype == torch.float8_e4m3fn:
             # FP8 branch
-            scale = weights.get_tensor(f"{prefix}.weight_scale", to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
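
For context on the reshape(-1) added above: torch.cat refuses zero-dimensional tensors, so a per-tensor weight_scale loaded as a scalar has to be normalized to 1-D before the fused-column path can concatenate the per-prefix scales. A standalone sketch of that behaviour (hypothetical scale values, no Weights API):

import torch

# Hypothetical per-tensor scales, loaded as 0-dim tensors from a checkpoint.
scale_q = torch.tensor(0.02)
scale_k = torch.tensor(0.03)

# torch.cat raises "zero-dimensional tensor ... cannot be concatenated" on 0-dim inputs,
# so each scale is flattened to shape (1,) first.
fused = torch.cat([scale_q.reshape(-1), scale_k.reshape(-1)], dim=0)
print(fused.shape)  # torch.Size([2])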
@@ -102,7 +104,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
             # FP8 branch
             scale = weights.get_packed_sharded(
                 f"{prefix}.weight_scale", dim=0, block_sizes=block_sizes, to_dtype=False
-            )
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -115,8 +117,12 @@ class HybridFP8UnquantLoader(WeightsLoader):
             return UnquantizedWeight(w)

     def get_multi_weights_col(self, weights: "Weights", prefixes: List[str], dim: int):
-        w = [weights.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
-        w = torch.cat(w, dim=dim)
+        # FIXME: Force to_device to false as fp8 weights do not support torch.cat on device yet
+        w = [
+            weights.get_sharded(f"{p}.weight", dim=0, to_device=False) for p in prefixes
+        ]
+        # Concat then send to the device
+        w = torch.cat(w, dim=dim).to(weights.device)

         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
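
The FIXME above works around torch.cat not being supported for float8_e4m3fn tensors on the GPU in the PyTorch build targeted at the time; the shards are therefore kept on the host, concatenated once, and moved in a single transfer. A rough sketch of that load-then-move pattern with ordinary tensors (hypothetical shapes, no Weights API):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-ins for the per-prefix weight shards (e.g. q/k/v), left on the CPU by the loader.
shards = [torch.randn(64, 32) for _ in range(3)]

# Concatenate on the host, then do one device transfer for the fused tensor.
w = torch.cat(shards, dim=0).to(device)
print(w.shape, w.device)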
@@ -124,7 +130,7 @@ class HybridFP8UnquantLoader(WeightsLoader):
                 weights.get_sharded(f"{p}.weight_scale", dim=0, to_dtype=False)
                 for p in prefixes
             ]
-            scale = torch.cat(scale, dim=0)
+            scale = torch.cat(scale, dim=0).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -140,7 +146,9 @@ class HybridFP8UnquantLoader(WeightsLoader):
         w = weights.get_sharded(f"{prefix}.weight", dim=1)
         # FP8 branch
         if w.dtype == torch.float8_e4m3fn:
-            scale = weights.get_sharded(f"{prefix}.weight_scale", dim=0, to_dtype=False)
+            scale = weights.get_tensor(
+                f"{prefix}.weight_scale", to_dtype=False
+            ).reshape(-1)
             return Fp8Weight(
                 weight=w,
                 weight_scale=scale,
@@ -504,7 +504,7 @@ class GPTQMarlinFP8Linear(nn.Module):
     def __init__(
         self,
         qweight: torch.Tensor,
-        scale: torch.Tensor,
+        scales: torch.Tensor,
         bias: Optional[torch.Tensor],
     ) -> None:
         super().__init__()
@@ -514,8 +514,11 @@ class GPTQMarlinFP8Linear(nn.Module):

         log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel")

-        scale = scale.to(torch.float16)
-        qweight, scales = repack_fp8_for_marlin(qweight, scale)
+        scales = scales.unsqueeze(0)
+        if scales.shape[1] == 1:
+            out_features, in_features = qweight.shape
+            scales = scales.repeat(1, out_features)
+        qweight, scales = repack_fp8_for_marlin(qweight, scales)

         in_features = qweight.shape[0] * MARLIN_TILE_SIZE
         out_features = scales.shape[1]
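
With the per-tensor reshape in the loader, the Marlin fallback above now receives scales of shape (1,) and broadcasts them to one scale per output column, the layout the repack step and the out_features = scales.shape[1] line expect. A shape-only sketch (hypothetical sizes, no Marlin kernels involved):

import torch

out_features, in_features = 128, 256
qweight = torch.empty(out_features, in_features)  # stand-in for the FP8 weight

scales = torch.tensor([0.02])     # per-tensor scale after reshape(-1)
scales = scales.unsqueeze(0)      # shape (1, 1)
if scales.shape[1] == 1:
    # Broadcast the single scale across all output columns.
    scales = scales.repeat(1, out_features)

print(scales.shape)  # torch.Size([1, 128])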
@@ -530,13 +533,13 @@ class GPTQMarlinFP8Linear(nn.Module):
         )

     @classmethod
-    def from_unquant(cls, weight, bias, _dtype):
-        qweight, scale = fp8_quantize(weight)
-        return cls(qweight=qweight, scale=scale, bias=bias)
+    def from_unquant(cls, weight, bias, dtype):
+        qweight, scales = fp8_quantize(weight)
+        return cls(qweight=qweight, scales=scales.to(dtype), bias=bias)

     @classmethod
-    def from_fp8(cls, weight, scale, _input_scale, bias, _dtype):
-        return cls(qweight=weight, scale=scale, bias=bias)
+    def from_fp8(cls, weight, scale, _input_scale, bias, dtype):
+        return cls(qweight=weight, scales=scale.to(dtype), bias=bias)

     def forward(self, A: torch.Tensor) -> torch.Tensor:
         assert marlin_kernels is not None
@@ -591,7 +594,7 @@ def pack_fp8_as_int32(fp8_tensor: torch.Tensor) -> torch.Tensor:
     return packed


-def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
+def repack_fp8_for_marlin(weight: torch.Tensor, scales: torch.Tensor):
     """
     Repack FP8 tensor for GPTQ-Marlin.
     """
@@ -608,7 +611,6 @@ def repack_fp8_for_marlin(weight: torch.Tensor, scale: torch.Tensor):
         qweight, perm, in_features, out_features, 8
     )

-    scales = scale.reshape(1, 1).repeat(1, out_features)
    scales = permute_scales(scales)

    return repacked, scales
@@ -621,7 +623,7 @@ class MarlinWeight(Weight):

     Attributes:
         B (torch.Tensor): int4-quantized weights packed into int32.
-        s (torch.Tensor): float16 scales.
+        s (torch.Tensor): bfloat16/float16 scales.
     """

     B: torch.Tensor
@@ -629,7 +631,7 @@ class MarlinWeight(Weight):

     def __post_init__(self):
         assert self.B.dtype == torch.int32
-        assert self.s.dtype == torch.float16
+        assert self.s.dtype in [torch.float16, torch.bfloat16]

     def get_linear(self, bias: torch.Tensor):
         return MarlinLinear(weight=self, bias=bias)
@@ -306,14 +306,32 @@ def get_model(
     max_input_tokens: int,
 ) -> Model:
     global FLASH_ATTENTION

+    config_dict, _ = PretrainedConfig.get_config_dict(
+        model_id, revision=revision, trust_remote_code=trust_remote_code
+    )
+    model_type = config_dict.get("model_type", None)
+
+    quantization_config = config_dict.get("quantization_config", None)
+    if quantization_config is not None and quantize is None:
+        method = quantization_config.get("quant_method", None)
+        if method in {"gptq", "awq", "exl2"}:
+            log_master(logger.info, f"Auto selecting quantization method {method}")
+            quantize = method
+        elif method == "fbgemm_fp8":
+            log_master(logger.info, "Auto selecting quantization method fp8")
+            quantize = "fp8"
+        else:
+            log_master(logger.warning, f"Unknown quantization method {method}")
+
     if dtype is None:
         if quantize in ["awq", "exl2", "gptq", "marlin"]:
             # These quantizers only work with float16 params.
             dtype = torch.float16
         elif quantize == "fp8":
-            from text_generation_server.layers.fp8 import FBGEMM_MM_AVAILABLE
+            from text_generation_server.layers.fp8 import FBGEMM_DYN_AVAILABLE

-            if FBGEMM_MM_AVAILABLE:
+            if FBGEMM_DYN_AVAILABLE:
                 # fbgemm kernels are fp8xfp8->bf16
                 dtype = torch.bfloat16
             else:
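
The quantization auto-selection now runs before the dtype defaulting, so fbgemm_fp8 checkpoints resolve to quantize="fp8" in time for the bfloat16/float16 choice below it. A minimal sketch of that selection over a plain dict (hypothetical config values, logging calls omitted):

config_dict = {
    "model_type": "llama",
    "quantization_config": {"quant_method": "fbgemm_fp8"},
}
quantize = None

quantization_config = config_dict.get("quantization_config", None)
if quantization_config is not None and quantize is None:
    method = quantization_config.get("quant_method", None)
    if method in {"gptq", "awq", "exl2"}:
        quantize = method
    elif method == "fbgemm_fp8":
        # fbgemm FP8 checkpoints are served through the generic fp8 path.
        quantize = "fp8"

print(quantize)  # fp8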
@@ -332,11 +350,6 @@ def get_model(
     else:
         set_speculate(0)

-    config_dict, _ = PretrainedConfig.get_config_dict(
-        model_id, revision=revision, trust_remote_code=trust_remote_code
-    )
-    model_type = config_dict.get("model_type", None)
-
     speculator = None
     if "medusa_num_heads" in config_dict:
         medusa_model_id = model_id
@@ -451,14 +464,6 @@ def get_model(
         raise RuntimeError(
             f"Could not determine model type for {model_id} revision {revision}"
         )
-    quantization_config = config_dict.get("quantization_config", None)
-    if quantization_config is not None and quantize is None:
-        method = quantization_config.get("quant_method", None)
-        if method in {"gptq", "awq", "exl2"}:
-            log_master(logger.info, f"Auto selecting quantization method {method}")
-            quantize = method
-        else:
-            log_master(logger.warning, f"Unknown quantization method {method}")

     if quantize == "exl2" and sharded:
         raise RuntimeError(
@@ -230,7 +230,9 @@ class Weights:
             tensor = tensor.to(device=self.device)
         return tensor

-    def get_partial_sharded(self, tensor_name: str, dim: int, to_dtype=True):
+    def get_partial_sharded(
+        self, tensor_name: str, dim: int, to_device=True, to_dtype=True
+    ):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -256,10 +258,11 @@ class Weights:
             and to_dtype
         ):
             tensor = tensor.to(dtype=self.dtype)
-        tensor = tensor.to(device=self.device)
+        if to_device:
+            tensor = tensor.to(device=self.device)
         return tensor

-    def get_sharded(self, tensor_name: str, dim: int, to_dtype=True):
+    def get_sharded(self, tensor_name: str, dim: int, to_device=True, to_dtype=True):
         filename, tensor_name = self.get_filename(tensor_name)
         f = self._get_handle(filename)
         slice_ = f.get_slice(tensor_name)
@@ -268,7 +271,9 @@ class Weights:
         assert (
             size % world_size == 0
         ), f"The choosen size {size} is not compatible with sharding on {world_size} shards"
-        return self.get_partial_sharded(tensor_name, dim, to_dtype=to_dtype)
+        return self.get_partial_sharded(
+            tensor_name, dim, to_device=to_device, to_dtype=to_dtype
+        )

     def get_packed_sharded(
         self,