parent
0ad7f6f87d
commit
06d0e880e0
|
@ -34,6 +34,7 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
TensorParallelRowLinear,
|
||||
)
|
||||
from text_generation_server.utils.weights import DefaultWeightsLoader
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
|
@ -682,7 +683,7 @@ class Idefics2Connector(nn.Module):
|
|||
class Idefics2ForConditionalGeneration(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
super().__init__()
|
||||
config.vision_config.quantize = config.quantize
|
||||
config.vision_config.quantize = None
|
||||
config.vision_config.speculator = config.speculator
|
||||
config.text_config.quantize = config.quantize
|
||||
config.text_config.speculator = config.speculator
|
||||
|
@ -695,16 +696,28 @@ class Idefics2ForConditionalGeneration(nn.Module):
|
|||
name="text_model",
|
||||
)
|
||||
self.dtype = weights.dtype
|
||||
|
||||
# The vision and connector models are not quantized.
|
||||
with weights.use_loader(DefaultWeightsLoader()):
|
||||
self.vision_model = Idefics2VisionTransformer(
|
||||
prefix=f"{prefix}.model.vision_model" if prefix else "model.vision_model",
|
||||
prefix=(
|
||||
f"{prefix}.model.vision_model" if prefix else "model.vision_model"
|
||||
),
|
||||
config=vision_config,
|
||||
weights=weights,
|
||||
)
|
||||
|
||||
quantize = config.quantize
|
||||
try:
|
||||
config.quantize = None
|
||||
self.connector = Idefics2Connector(
|
||||
prefix=f"{prefix}.model.connector" if prefix else "model.connector",
|
||||
config=config,
|
||||
weights=weights,
|
||||
)
|
||||
finally:
|
||||
config.quantize = quantize
|
||||
|
||||
self.config = config
|
||||
self.image_seq_len = config.perceiver_config.resampler_n_latents
|
||||
self.image_token_id = config.image_token_id
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
from safetensors import safe_open
|
||||
|
@ -306,6 +307,20 @@ class Weights:
|
|||
def get_weights_row(self, prefix: str):
|
||||
return self.weights_loader.get_weights_row(self, prefix)
|
||||
|
||||
@contextmanager
|
||||
def use_loader(self, weights_loader: WeightsLoader):
|
||||
"""
|
||||
This method is a context manager that can be used to use `Weights` with
|
||||
a different loader for the duration of the context.
|
||||
"""
|
||||
|
||||
old_loader = self.weights_loader
|
||||
self.weights_loader = weights_loader
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.weights_loader = old_loader
|
||||
|
||||
|
||||
def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue