fix: fix gpt-q with groupsize = -1 (#1358)
This commit is contained in:
parent
8428ed1011
commit
d077150eb7
|
@ -213,6 +213,9 @@ message DecodeResponse {
|
|||
message WarmupRequest {
|
||||
/// Batch to warmup on
|
||||
Batch batch = 1;
|
||||
uint32 max_input_length = 2;
|
||||
uint32 max_prefill_tokens = 3;
|
||||
uint32 max_total_tokens = 4;
|
||||
}
|
||||
|
||||
/// Empty response
|
||||
|
|
|
@ -145,7 +145,13 @@ impl Client {
|
|||
max_tokens: 0,
|
||||
};
|
||||
|
||||
let request = tonic::Request::new(WarmupRequest { batch: Some(batch) }).inject_context();
|
||||
let request = tonic::Request::new(WarmupRequest {
|
||||
batch: Some(batch),
|
||||
max_input_length,
|
||||
max_prefill_tokens,
|
||||
max_total_tokens,
|
||||
})
|
||||
.inject_context();
|
||||
let response = self.stub.warmup(request).await?.into_inner();
|
||||
Ok(response.max_supported_total_tokens)
|
||||
}
|
||||
|
|
|
@ -19,9 +19,16 @@ from text_generation_server.models.idefics_causal_lm import IdeficsCausalLMBatch
|
|||
|
||||
|
||||
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
||||
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
cache: Cache,
|
||||
quantize: Optional[str],
|
||||
server_urls: List[str],
|
||||
):
|
||||
self.cache = cache
|
||||
self.model = model
|
||||
self.quantize = quantize
|
||||
self.server_urls = server_urls
|
||||
# For some reason, inference_mode does not work well with GLOO which we use on CPU
|
||||
if model.device.type == "cuda":
|
||||
|
@ -56,6 +63,21 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
|
|||
return generate_pb2.FilterBatchResponse(batch=filtered_batch.to_pb())
|
||||
|
||||
async def Warmup(self, request, context):
|
||||
if self.quantize == "gptq":
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
# For which we have the finale shapes only after the model has loaded
|
||||
# This will allocate those buffers.
|
||||
from text_generation_server.utils.layers import (
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
|
||||
set_device(self.model.device)
|
||||
create_exllama_buffers(request.max_prefill_tokens)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if (
|
||||
self.model.batch_type == IdeficsCausalLMBatch
|
||||
): # Hack, i would rather use kwargs in the `from_pb` call
|
||||
|
@ -184,21 +206,6 @@ def serve(
|
|||
logger.exception("Error when initializing model")
|
||||
raise
|
||||
|
||||
if quantize == "gptq":
|
||||
try:
|
||||
# When using GPTQ, Exllama kernels need some global kernels
|
||||
# For which we have the finale shapes only after the model has loaded
|
||||
# This will allocate those buffers.
|
||||
from text_generation_server.utils.layers import (
|
||||
create_exllama_buffers,
|
||||
set_device,
|
||||
)
|
||||
|
||||
set_device(model.device)
|
||||
create_exllama_buffers()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
server = aio.server(
|
||||
interceptors=[
|
||||
ExceptionInterceptor(),
|
||||
|
@ -206,7 +213,7 @@ def serve(
|
|||
]
|
||||
)
|
||||
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
|
||||
TextGenerationService(model, Cache(), server_urls), server
|
||||
TextGenerationService(model, Cache(), quantize, server_urls), server
|
||||
)
|
||||
SERVICE_NAMES = (
|
||||
generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
|
||||
|
|
|
@ -37,19 +37,12 @@ def set_device(device):
|
|||
DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers():
|
||||
def create_exllama_buffers(max_total_tokens: int):
|
||||
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
|
||||
|
||||
assert DEVICE is not None, "call set_device first"
|
||||
|
||||
if ACT_ORDER:
|
||||
# TODO: this should be set to rust side `max_total_tokens`, but TGI
|
||||
# does not offer an API to expose this variable to python, as this variable
|
||||
# is handled by the client but it appears the model is initialized by the server.
|
||||
# An alternative could be to initialize the buffers during warmup.
|
||||
# Dummy
|
||||
max_total_tokens = 2048
|
||||
else:
|
||||
if not ACT_ORDER:
|
||||
max_total_tokens = 1
|
||||
|
||||
# This temp_state buffer is required to reorder X in the act-order case.
|
||||
|
|
|
@ -101,7 +101,7 @@ def set_device(device):
|
|||
DEVICE = device
|
||||
|
||||
|
||||
def create_exllama_buffers():
|
||||
def create_exllama_buffers(max_total_tokens: int):
|
||||
global FIXED_BYTES, LAYERS, DEVICE
|
||||
temp_dq = ExLlamaV2DeviceTensors(DEVICE, FIXED_BYTES)
|
||||
|
||||
|
@ -138,17 +138,6 @@ class QuantLinear(nn.Module):
|
|||
self.bias = bias if bias is not None else None
|
||||
self.group_size = groupsize
|
||||
|
||||
infeatures = self.infeatures
|
||||
outfeatures = self.outfeatures
|
||||
assert qweight.shape == (infeatures // 32 * self.bits, outfeatures)
|
||||
assert infeatures % self.group_size == 0
|
||||
assert qzeros.shape == (
|
||||
infeatures // self.group_size,
|
||||
outfeatures // 32 * self.bits,
|
||||
)
|
||||
assert scales.shape == (infeatures // self.group_size, outfeatures)
|
||||
assert g_idx.shape == (infeatures,), f"{g_idx.shape}, {infeatures}"
|
||||
|
||||
global FIXED_BYTES, LAYERS
|
||||
FIXED_BYTES = max(FIXED_BYTES, self.scratch_space_fixed())
|
||||
LAYERS.append(self)
|
||||
|
|
|
@ -281,18 +281,18 @@ class Weights:
|
|||
else:
|
||||
logger.info(f"Using exllama kernels v{HAS_EXLLAMA}")
|
||||
|
||||
if use_exllama:
|
||||
if use_exllama and groupsize != -1:
|
||||
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
|
||||
scales = self.get_sharded(f"{prefix}.scales", dim=0)
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
g_idx = g_idx - g_idx[0]
|
||||
else:
|
||||
# The triton kernel reorders the scales/zero points instead of the weight/activation.
|
||||
# Thus, each rank needs the full qzeros/scales.
|
||||
qzeros = self.get_tensor(f"{prefix}.qzeros")
|
||||
scales = self.get_tensor(f"{prefix}.scales")
|
||||
|
||||
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
|
||||
|
||||
if use_exllama:
|
||||
g_idx = g_idx - g_idx[0]
|
||||
|
||||
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
|
||||
elif quantize == "awq":
|
||||
bits, groupsize = self._get_gptq_params()
|
||||
|
|
Loading…
Reference in New Issue