fix roundtript text_encoder
This commit is contained in:
parent
9bfa22136b
commit
95165974ce
|
@ -746,8 +746,6 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||
|
||||
text_model_dict = {}
|
||||
|
||||
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
||||
|
||||
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
||||
|
||||
for key in keys:
|
||||
|
@ -756,6 +754,7 @@ def convert_open_clip_checkpoint(checkpoint):
|
|||
if key in textenc_conversion_map:
|
||||
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
||||
if key.startswith("cond_stage_model.model.transformer."):
|
||||
d_model = int(checkpoint[key].shape[0]/3)
|
||||
new_key = key[len("cond_stage_model.model.transformer.") :]
|
||||
if new_key.endswith(".in_proj_weight"):
|
||||
new_key = new_key[: -len(".in_proj_weight")]
|
||||
|
|
Loading…
Reference in New Issue