diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 4b36c0e9c..0689699cf 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -212,7 +212,7 @@ class StableDiffusionModelHijack: model_embeddings = m.cond_stage_model.roberta.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self) m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self) - + elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder: model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) @@ -258,7 +258,7 @@ class StableDiffusionModelHijack: if hasattr(m, 'cond_stage_model'): delattr(m, 'cond_stage_model') - + elif type(m.cond_stage_model) == sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords: m.cond_stage_model = m.cond_stage_model.wrapped diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 9ba89dfc0..deab2f6e2 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -95,8 +95,7 @@ def guess_model_config_from_state_dict(sd, filename): if diffusion_model_input.shape[1] == 8: return config_instruct_pix2pix - - # import pdb; pdb.set_trace() + if sd.get('cond_stage_model.roberta.embeddings.word_embeddings.weight', None) is not None: if sd.get('cond_stage_model.transformation.weight').size()[0] == 1024: return config_alt_diffusion_m18 diff --git a/modules/xlmr_m18.py b/modules/xlmr_m18.py index 18785692a..a727e8655 100644 --- a/modules/xlmr_m18.py +++ b/modules/xlmr_m18.py @@ -1,4 +1,4 @@ -from transformers import BertPreTrainedModel,BertModel,BertConfig +from transformers import BertPreTrainedModel,BertConfig import torch.nn as nn import torch from transformers.models.xlm_roberta.configuration_xlm_roberta import XLMRobertaConfig @@ -28,7 +28,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): config_class = BertSeriesConfig def __init__(self, config=None, **kargs): - # modify initialization for autoloading + # modify initialization for autoloading if config is None: config = XLMRobertaConfig() config.attention_probs_dropout_prob= 0.1 @@ -80,7 +80,7 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): text["attention_mask"] = torch.tensor( text['attention_mask']).to(device) features = self(**text) - return features['projection_state'] + return features['projection_state'] def forward( self, @@ -147,8 +147,8 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): "hidden_states": outputs.hidden_states, "attentions": outputs.attentions, } - - + + # return { # 'pooler_output':pooler_output, # 'last_hidden_state':outputs.last_hidden_state, @@ -161,4 +161,4 @@ class BertSeriesModelWithTransformation(BertPreTrainedModel): class RobertaSeriesModelWithTransformation(BertSeriesModelWithTransformation): base_model_prefix = 'roberta' - config_class= RobertaSeriesConfig \ No newline at end of file + config_class= RobertaSeriesConfig