fix roundtript text_encoder

This commit is contained in:
Janek Mann 2023-02-13 17:00:53 +00:00
parent 9bfa22136b
commit 95165974ce
1 changed files with 1 additions and 2 deletions

View File

@ -746,8 +746,6 @@ def convert_open_clip_checkpoint(checkpoint):
text_model_dict = {}
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
for key in keys:
@ -756,6 +754,7 @@ def convert_open_clip_checkpoint(checkpoint):
if key in textenc_conversion_map:
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
if key.startswith("cond_stage_model.model.transformer."):
d_model = int(checkpoint[key].shape[0]/3)
new_key = key[len("cond_stage_model.model.transformer.") :]
if new_key.endswith(".in_proj_weight"):
new_key = new_key[: -len(".in_proj_weight")]