41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import argparse
|
|
|
|
from diffusers import UnCLIPImageVariationPipeline, UnCLIPPipeline
|
|
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the txt2img -> image-variation conversion.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``dump_path``, ``txt2img_unclip`` and
        ``image_encoder`` attributes.

    Raises:
        SystemExit: if ``--dump_path`` is missing or an option is malformed
            (raised by ``argparse``).
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--txt2img_unclip",
        default="kakaobrain/karlo-v1-alpha",
        type=str,
        required=False,
        help="The pretrained txt2img unclip.",
    )

    # CLIP vision checkpoint loaded as the image encoder. The default is the
    # checkpoint this conversion has always used, so existing invocations
    # produce the same result.
    parser.add_argument(
        "--image_encoder",
        default="openai/clip-vit-large-patch14",
        type=str,
        required=False,
        help="The pretrained CLIP vision model to use as the image encoder.",
    )

    return parser.parse_args(argv)


def main(argv=None):
    """Build an UnCLIP image-variation pipeline from a txt2img UnCLIP pipeline and save it.

    The decoder, text encoder, tokenizer, text projection, super-resolution
    models and schedulers are taken from the loaded txt2img pipeline. A CLIP
    image processor and a CLIP vision encoder are added. Components of the
    txt2img pipeline that are not passed below are not carried over.

    Args:
        argv: Optional list of argument strings, forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    txt2img = UnCLIPPipeline.from_pretrained(args.txt2img_unclip)

    # Constructed with default settings; no pretrained preprocessing config is
    # loaded. NOTE(review): presumably the defaults match the chosen encoder's
    # expected preprocessing -- confirm when using a non-default --image_encoder.
    feature_extractor = CLIPImageProcessor()
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder)

    img2img = UnCLIPImageVariationPipeline(
        decoder=txt2img.decoder,
        text_encoder=txt2img.text_encoder,
        tokenizer=txt2img.tokenizer,
        text_proj=txt2img.text_proj,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
        super_res_first=txt2img.super_res_first,
        super_res_last=txt2img.super_res_last,
        decoder_scheduler=txt2img.decoder_scheduler,
        super_res_scheduler=txt2img.super_res_scheduler,
    )

    img2img.save_pretrained(args.dump_path)


if __name__ == "__main__":
    main()
|