fix: avoid unused perceiver config
This commit is contained in:
parent
4f61a305f6
commit
a200d44e12
|
@ -713,82 +713,6 @@ class Idefics3Connector(nn.Module):
|
||||||
return image_hidden_states
|
return image_hidden_states
|
||||||
|
|
||||||
|
|
||||||
class PerceiverConfig:
|
|
||||||
def __init__(self, config_dict):
|
|
||||||
self._name_or_path = config_dict.get("_name_or_path", "")
|
|
||||||
self.add_cross_attention = config_dict.get("add_cross_attention", False)
|
|
||||||
self.architectures = config_dict.get("architectures", None)
|
|
||||||
self.attention_dropout = config_dict.get("attention_dropout", 0.0)
|
|
||||||
self.bad_words_ids = config_dict.get("bad_words_ids", None)
|
|
||||||
self.begin_suppress_tokens = config_dict.get("begin_suppress_tokens", None)
|
|
||||||
self.bos_token_id = config_dict.get("bos_token_id", None)
|
|
||||||
self.chunk_size_feed_forward = config_dict.get("chunk_size_feed_forward", 0)
|
|
||||||
self.cross_attention_hidden_size = config_dict.get(
|
|
||||||
"cross_attention_hidden_size", None
|
|
||||||
)
|
|
||||||
self.decoder_start_token_id = config_dict.get("decoder_start_token_id", None)
|
|
||||||
self.diversity_penalty = config_dict.get("diversity_penalty", 0.0)
|
|
||||||
self.do_sample = config_dict.get("do_sample", False)
|
|
||||||
self.early_stopping = config_dict.get("early_stopping", False)
|
|
||||||
self.encoder_no_repeat_ngram_size = config_dict.get(
|
|
||||||
"encoder_no_repeat_ngram_size", 0
|
|
||||||
)
|
|
||||||
self.eos_token_id = config_dict.get("eos_token_id", None)
|
|
||||||
self.exponential_decay_length_penalty = config_dict.get(
|
|
||||||
"exponential_decay_length_penalty", None
|
|
||||||
)
|
|
||||||
self.finetuning_task = config_dict.get("finetuning_task", None)
|
|
||||||
self.forced_bos_token_id = config_dict.get("forced_bos_token_id", None)
|
|
||||||
self.forced_eos_token_id = config_dict.get("forced_eos_token_id", None)
|
|
||||||
self.hidden_act = config_dict.get("hidden_act", "silu")
|
|
||||||
self.id2label = config_dict.get("id2label", {"0": "LABEL_0", "1": "LABEL_1"})
|
|
||||||
self.is_decoder = config_dict.get("is_decoder", False)
|
|
||||||
self.is_encoder_decoder = config_dict.get("is_encoder_decoder", False)
|
|
||||||
self.label2id = config_dict.get("label2id", {"LABEL_0": 0, "LABEL_1": 1})
|
|
||||||
self.length_penalty = config_dict.get("length_penalty", 1.0)
|
|
||||||
self.max_length = config_dict.get("max_length", 20)
|
|
||||||
self.min_length = config_dict.get("min_length", 0)
|
|
||||||
self.model_type = config_dict.get("model_type", "idefics3")
|
|
||||||
self.no_repeat_ngram_size = config_dict.get("no_repeat_ngram_size", 0)
|
|
||||||
self.num_beam_groups = config_dict.get("num_beam_groups", 1)
|
|
||||||
self.num_beams = config_dict.get("num_beams", 1)
|
|
||||||
self.num_key_value_heads = config_dict.get("num_key_value_heads", 1)
|
|
||||||
self.num_return_sequences = config_dict.get("num_return_sequences", 1)
|
|
||||||
self.output_attentions = config_dict.get("output_attentions", False)
|
|
||||||
self.output_hidden_states = config_dict.get("output_hidden_states", False)
|
|
||||||
self.output_scores = config_dict.get("output_scores", False)
|
|
||||||
self.pad_token_id = config_dict.get("pad_token_id", 128002)
|
|
||||||
self.prefix = config_dict.get("prefix", None)
|
|
||||||
self.problem_type = config_dict.get("problem_type", None)
|
|
||||||
self.pruned_heads = config_dict.get("pruned_heads", {})
|
|
||||||
self.qk_layer_norms_perceiver = config_dict.get(
|
|
||||||
"qk_layer_norms_perceiver", False
|
|
||||||
)
|
|
||||||
self.remove_invalid_values = config_dict.get("remove_invalid_values", False)
|
|
||||||
self.repetition_penalty = config_dict.get("repetition_penalty", 1.0)
|
|
||||||
self.resampler_depth = config_dict.get("resampler_depth", 6)
|
|
||||||
self.resampler_head_dim = config_dict.get("resampler_head_dim", 96)
|
|
||||||
self.resampler_n_heads = config_dict.get("resampler_n_heads", 16)
|
|
||||||
self.resampler_n_latents = config_dict.get("resampler_n_latents", 64)
|
|
||||||
self.return_dict = config_dict.get("return_dict", True)
|
|
||||||
self.return_dict_in_generate = config_dict.get("return_dict_in_generate", False)
|
|
||||||
self.sep_token_id = config_dict.get("sep_token_id", None)
|
|
||||||
self.suppress_tokens = config_dict.get("suppress_tokens", None)
|
|
||||||
self.task_specific_params = config_dict.get("task_specific_params", None)
|
|
||||||
self.temperature = config_dict.get("temperature", 1.0)
|
|
||||||
self.tf_legacy_loss = config_dict.get("tf_legacy_loss", False)
|
|
||||||
self.tie_encoder_decoder = config_dict.get("tie_encoder_decoder", False)
|
|
||||||
self.tie_word_embeddings = config_dict.get("tie_word_embeddings", True)
|
|
||||||
self.tokenizer_class = config_dict.get("tokenizer_class", None)
|
|
||||||
self.top_k = config_dict.get("top_k", 50)
|
|
||||||
self.top_p = config_dict.get("top_p", 1.0)
|
|
||||||
self.torch_dtype = config_dict.get("torch_dtype", None)
|
|
||||||
self.torchscript = config_dict.get("torchscript", False)
|
|
||||||
self.transformers_version = config_dict.get("transformers_version", "4.43.2")
|
|
||||||
self.typical_p = config_dict.get("typical_p", 1.0)
|
|
||||||
self.use_bfloat16 = config_dict.get("use_bfloat16", False)
|
|
||||||
|
|
||||||
|
|
||||||
class Idefics3ForConditionalGeneration(nn.Module):
|
class Idefics3ForConditionalGeneration(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
@ -824,10 +748,6 @@ class Idefics3ForConditionalGeneration(nn.Module):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
config.text_config.perceiver_config = PerceiverConfig(
|
|
||||||
config_dict=config.text_config.perceiver_config
|
|
||||||
)
|
|
||||||
self.image_seq_len = config.text_config.perceiver_config.resampler_n_latents
|
|
||||||
self.image_token_id = config.image_token_id
|
self.image_token_id = config.image_token_id
|
||||||
self.pad_token_id = (
|
self.pad_token_id = (
|
||||||
config.pad_token_id if config.pad_token_id is not None else -1
|
config.pad_token_id if config.pad_token_id is not None else -1
|
||||||
|
|
Loading…
Reference in New Issue