diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index 827e47df..4d68109a 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -38,9 +38,12 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + # if scales is a scalar (0D tensor), convert it to a 1D tensor + if scales.dim() == 0: + scales = scales.unsqueeze(0) + scales = scales.unsqueeze(0) - # repack weights for Marlin if a single scale is provided - if scales.size(0) == 1: + if scales.shape[1] == 1: out_features, in_features = qweight.shape scales = scales.repeat(1, out_features) qweight, scales = repack_fp8_for_marlin(qweight, scales)