Include CLIPTextModel parameters in conversion (#695)
This commit is contained in:
parent
08d4fb6e9f
commit
b9eea06e9f
|
@ -595,6 +595,22 @@ def convert_ldm_bert_checkpoint(checkpoint, config):
|
||||||
return hf_model
|
return hf_model
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ldm_clip_checkpoint(checkpoint):
|
||||||
|
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
|
|
||||||
|
keys = list(checkpoint.keys())
|
||||||
|
|
||||||
|
text_model_dict = {}
|
||||||
|
|
||||||
|
for key in keys:
|
||||||
|
if key.startswith("cond_stage_model.transformer"):
|
||||||
|
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
||||||
|
|
||||||
|
text_model.load_state_dict(text_model_dict)
|
||||||
|
|
||||||
|
return text_model
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
@ -668,7 +684,7 @@ if __name__ == "__main__":
|
||||||
# Convert the text model.
|
# Convert the text model.
|
||||||
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
|
||||||
if text_model_type == "FrozenCLIPEmbedder":
|
if text_model_type == "FrozenCLIPEmbedder":
|
||||||
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
|
text_model = convert_ldm_clip_checkpoint(checkpoint)
|
||||||
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
|
||||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||||
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
feature_extractor = AutoFeatureExtractor.from_pretrained("CompVis/stable-diffusion-safety-checker")
|
||||||
|
|
Loading…
Reference in New Issue