# EveryDream2trainer/scripts/mt_grid.py
import argparse
import os

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image, ImageDraw, ImageFont
from torch.cuda.amp import autocast
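
# NOTE: torch.cuda.amp.autocast is deprecated in recent torch releases; if this
# script warns on your setup, torch.amp.autocast("cuda") is the drop-in replacement.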

def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
                      steps: int = 30, batch_size: int = 1):
    """
    Generates a batch of samples for a prompt at the given CFG scale and returns the images.
    """
    # autocast runs the pipeline in mixed precision on CUDA to save memory and time
    with autocast():
        images = pipe(prompt,
                      num_inference_steps=steps,
                      num_images_per_prompt=batch_size,
                      guidance_scale=cfg,
                      generator=gen,
                      height=height,
                      width=width,
                      ).images
    return images

def generate_simple(prompt, model):
    """Convenience helper: load a pipeline from `model` and generate one 512x512 image."""
    pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
    images = __generate_sample(pipe, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
    return images[0]
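
# Example use of generate_simple (hypothetical prompt and model id, not part of this repo's runs):
#   img = generate_simple("a kanji character, bold strokes", "runwayml/stable-diffusion-v1-5")
#   img.save("sample.png")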

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--epochs", type=int, default=60)
    args = argparser.parse_args()
    epochs = args.epochs

    path = "/mnt/nvme/mt/val"

    # Pick the matching checkpoint pair (full-caption vs. short-caption training runs)
    if epochs == 100:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep100-gs159000"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep100-gs159000"
    elif epochs == 80:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep80-gs127200"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep80-gs127200"
    elif epochs == 60:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep60-gs95400"
    else:
        raise ValueError("epochs must be 100, 80, or 60")

    pipe1 = StableDiffusionPipeline.from_pretrained(model1).to("cuda")
    pipe2 = StableDiffusionPipeline.from_pretrained(model2).to("cuda")
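
    # For reproducible comparisons you could seed a generator and pass it instead of
    # gen=None below, e.g. (illustrative seed, not part of the original script):
    #   gen = torch.Generator(device="cuda").manual_seed(1234)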

    # Walk the validation set; every .txt holds the caption for the matching .png
    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".txt") and not file.endswith("file_list.txt"):
                txt_path = os.path.join(root, file)
                with open(txt_path, "r", encoding="utf-8") as f:
                    prompt = f.readline().strip()

                # Model 1 sees the full caption, model 2 only the first comma-separated phrase
                generated_image1 = __generate_sample(pipe1, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
                short_prompt = prompt.split(",")[0]
                generated_image2 = __generate_sample(pipe2, short_prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
                print(short_prompt)

                gt_path = txt_path.replace(".txt", ".png")
                print(f"Loading gt_path {gt_path}")
                ground_truth_image = Image.open(gt_path)
                ground_truth_image = ground_truth_image.resize((512, 512))

                # Side-by-side strip: ground truth | full-caption sample | short-caption sample
                combined_image = Image.new("RGB", (1536, 576), color=(96, 96, 96))
                combined_image.paste(ground_truth_image, (0, 0))
                combined_image.paste(generated_image1, (512, 0))
                combined_image.paste(generated_image2, (1024, 0))

                draw = ImageDraw.Draw(combined_image)
                font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 18)
                draw.text((0, 510), f"epochs={epochs}", font=font)
                draw.text((200, 510), "↑ ground truth ↑", font=font)
                draw.text((650, 510), "↑ trained & generated, full caption ↑", font=font)
                draw.text((1140, 510), "↑ trained & generated, short caption ↑", font=font)
                font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 24)
                draw.text((100, 536), prompt, font=font)
                draw.text((1240, 537), short_prompt, font=font)

                # Ensure the per-epoch output directory exists before saving
                output_dir = os.path.join("/mnt/nvme/mt", str(epochs))
                os.makedirs(output_dir, exist_ok=True)
                output_path = os.path.join(output_dir, f"{file}_compare.png")
                print(f"Saving to {output_path}")
                combined_image.save(output_path)
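
# Typical invocation (checkpoint and dataset paths above are machine-specific):
#   python scripts/mt_grid.py --epochs 80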