From 4b10c8c30b62a33aafaea33b91518467729d74b7 Mon Sep 17 00:00:00 2001 From: drbh Date: Wed, 14 Aug 2024 16:38:15 +0000 Subject: [PATCH] fix: improve scales change and revert conditional --- server/text_generation_server/layers/marlin/fp8.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/server/text_generation_server/layers/marlin/fp8.py b/server/text_generation_server/layers/marlin/fp8.py index 827e47df..4d68109a 100644 --- a/server/text_generation_server/layers/marlin/fp8.py +++ b/server/text_generation_server/layers/marlin/fp8.py @@ -38,9 +38,12 @@ class GPTQMarlinFP8Linear(nn.Module): log_once(logger.info, "GPU does not support FP8, using Marlin FP8 kernel") + # if scales is a scalar (0D tensor), convert it to a 1D tensor + if scales.dim() == 0: + scales = scales.unsqueeze(0) + scales = scales.unsqueeze(0) - # repack weights for Marlin if a single scale is provided - if scales.size(0) == 1: + if scales.shape[1] == 1: out_features, in_features = qweight.shape scales = scales.repeat(1, out_features) qweight, scales = repack_fp8_for_marlin(qweight, scales)