Add support for AWQ-quantized Idefics2 (#2233)

Fixes #2036.
Daniël de Kok 2024-07-16 07:58:25 +02:00 committed by GitHub
parent 0ad7f6f87d
commit 06d0e880e0
2 changed files with 39 additions and 11 deletions

server/text_generation_server/models/custom_modeling/idefics2.py

@@ -34,6 +34,7 @@ from text_generation_server.layers import (
    TensorParallelEmbedding,
    TensorParallelRowLinear,
)
+from text_generation_server.utils.weights import DefaultWeightsLoader
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
class Idefics2ForConditionalGeneration(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
-        config.vision_config.quantize = config.quantize
+        config.vision_config.quantize = None
        config.vision_config.speculator = config.speculator
        config.text_config.quantize = config.quantize
        config.text_config.speculator = config.speculator
@@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
            name="text_model",
        )
        self.dtype = weights.dtype
+        # The vision and connector models are not quantized.
+        with weights.use_loader(DefaultWeightsLoader()):
            self.vision_model = Idefics2VisionTransformer(
-            prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
+                prefix=(
+                    f"{prefix}.model.vision_model" if prefix else "model.vision_model"
+                ),
                config=vision_config,
                weights=weights,
            )
+            quantize = config.quantize
+            try:
+                config.quantize = None
                self.connector = Idefics2Connector(
                    prefix=f"{prefix}.model.connector" if prefix else "model.connector",
                    config=config,
                    weights=weights,
                )
+            finally:
+                config.quantize = quantize
        self.config = config
        self.image_seq_len = config.perceiver_config.resampler_n_latents
        self.image_token_id = config.image_token_id
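The try/finally above makes the temporary change reversible: `config.quantize` is cleared only while the connector is being built and is restored afterwards, even if construction raises, so only the text model ends up with quantized (AWQ) layers. As a minimal sketch, the same save/restore pattern can be written as a standalone context manager; `quantize_disabled` below is illustrative only and is not a helper in this commit or repository.

from contextlib import contextmanager


@contextmanager
def quantize_disabled(config):
    """Temporarily clear `config.quantize`, restoring the old value on exit.

    Sketch only: `config` stands for any object with a `quantize` attribute,
    such as the transformers config objects used in the constructor above.
    """
    saved = config.quantize
    config.quantize = None
    try:
        yield config
    finally:
        # Runs on normal exit and on exceptions raised while building layers.
        config.quantize = saved

With such a helper, the connector block above would collapse to `with quantize_disabled(config): self.connector = Idefics2Connector(...)`, which is behaviorally equivalent to the explicit try/finally in the diff.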

server/text_generation_server/utils/weights.py

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
+from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Union
from safetensors import safe_open
@@ -306,6 +307,20 @@ class Weights:
    def get_weights_row(self, prefix: str):
        return self.weights_loader.get_weights_row(self, prefix)
+    @contextmanager
+    def use_loader(self, weights_loader: WeightsLoader):
+        """
+        This method is a context manager that can be used to use `Weights` with
+        a different loader for the duration of the context.
+        """
+        old_loader = self.weights_loader
+        self.weights_loader = weights_loader
+        try:
+            yield
+        finally:
+            self.weights_loader = old_loader
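A short usage sketch for the new context manager (illustrative, not part of this diff): the `get_weights_*` accessors above delegate to `self.weights_loader`, so swapping in a `DefaultWeightsLoader` for the duration of a block returns plain, unquantized weights while the rest of the model keeps the loader picked for its quantization scheme. The `weights` instance and the module prefix below are assumed for the example.

from text_generation_server.utils.weights import DefaultWeightsLoader, Weights


def load_unquantized_row(weights: Weights, prefix: str):
    """Fetch one row-parallel weight using the default (unquantized) loader.

    The context manager swaps the loader only for this call; on exit, whether
    by return or exception, the previously configured loader is restored.
    """
    with weights.use_loader(DefaultWeightsLoader()):
        # get_weights_row() delegates to the active weights_loader, so this
        # returns a plain weight even when the checkpoint is AWQ-quantized.
        return weights.get_weights_row(prefix)


# Hypothetical usage; the prefix is made up for the example.
# row = load_unquantized_row(weights, "model.vision_model.head.proj")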
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
"""