Add support for SD 2.1 Turbo, by converting the state dict from SGM to LDM on load
This commit is contained in:
parent
293f44e6c1
commit
6080045b2a
|
@ -230,15 +230,19 @@ def select_checkpoint():
|
||||||
return checkpoint_info
|
return checkpoint_info
|
||||||
|
|
||||||
|
|
||||||
# Key-prefix remappings applied to SD 1.x checkpoints on load: older
# checkpoints address the CLIP text encoder directly under
# `cond_stage_model.transformer.`, while the loader expects the
# `text_model.` sub-module in the path.
checkpoint_dict_replacements_sd1 = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}
|
# Converts SD 2.1 Turbo from SGM to LDM format: the SGM codebase stores the
# text encoder under `conditioner.embedders.0.`, whereas the LDM loader
# expects it under `cond_stage_model.`.
checkpoint_dict_replacements_sd2_turbo = {
    'conditioner.embedders.0.': 'cond_stage_model.',
}
def transform_checkpoint_dict_key(k, replacements):
    """Rewrite checkpoint key *k* using the prefix table *replacements*.

    For every (old_prefix, new_prefix) pair, if *k* starts with old_prefix the
    prefix is swapped for new_prefix.  Keys matching no prefix are returned
    unchanged.

    NOTE(review): the diff view elides this function's tail; the trailing
    `return k` is reconstructed — without it the function would return None
    and every checkpoint key would be dropped by the caller.
    """
    for text, replacement in replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]
    return k
def get_state_dict_from_checkpoint(pl_sd):
    """Extract and normalize a model state dict from a loaded checkpoint.

    Unwraps a nested "state_dict" entry if present, detects whether the
    checkpoint is SD 2.1 Turbo (SGM key layout), and rewrites every key into
    the LDM naming the rest of the loader expects.
    """
    # Some checkpoints nest the weights under a "state_dict" key; unwrap it,
    # and drop any leftover "state_dict" entry inside the unwrapped dict.
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    # SD 2.1 Turbo ships in SGM layout: identified by the SGM-style text
    # encoder key with the SD2 embedding width of 1024.
    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024

    # Pick the prefix table once instead of re-testing the flag per key.
    replacements = checkpoint_dict_replacements_sd2_turbo if is_sd2_turbo else checkpoint_dict_replacements_sd1

    sd = {}
    for k, v in pl_sd.items():
        new_key = transform_checkpoint_dict_key(k, replacements)

        if new_key is not None:
            sd[new_key] = v

    # NOTE(review): the diff view elides the tail; reconstructed as the
    # in-place rewrite (same dict object returned) — confirm against the
    # full file.
    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd
Loading…
Reference in New Issue