2022-06-02 07:55:32 -06:00
|
|
|
#!/usr/bin/env python3
|
2022-06-07 07:03:53 -06:00
|
|
|
import os
|
|
|
|
import pathlib
|
2022-06-06 09:03:41 -06:00
|
|
|
from modeling_ddpm import DDPM
|
2022-06-07 07:03:53 -06:00
|
|
|
import PIL.Image
|
|
|
|
import numpy as np
|
2022-06-06 09:03:41 -06:00
|
|
|
|
2022-06-07 07:24:36 -06:00
|
|
|
# Checkpoint names to sample from, processed in this order. Each name is
# appended to the "fusing/" prefix when loading and is also used as the
# per-model output sub-directory name.
model_ids = [
    "ddpm-lsun-cat",
    "ddpm-lsun-cat-ema",
    "ddpm-lsun-church-ema",
    "ddpm-lsun-church",
    "ddpm-lsun-bedroom",
    "ddpm-lsun-bedroom-ema",
    "ddpm-cifar10-ema",
    "ddpm-cifar10",
    "ddpm-celeba-hq",
    "ddpm-celeba-hq-ema",
]
|
2022-06-06 10:00:06 -06:00
|
|
|
|
2022-06-07 07:03:53 -06:00
|
|
|
# For every checkpoint: load it, draw a batch of 4 samples, and write each
# sample to <output root>/<model_id>/image_<i>.png.
for model_id in model_ids:
    # One output directory per checkpoint, created on demand; an already
    # existing directory is fine (files with the same name are overwritten).
    path = os.path.join("/home/patrick/images/hf", model_id)
    pathlib.Path(path).mkdir(parents=True, exist_ok=True)

    ddpm = DDPM.from_pretrained("fusing/" + model_id)
    image = ddpm(batch_size=4)

    # NOTE(review): presumably the pipeline returns a tensor shaped
    # (batch, channels, height, width) with values in [-1, 1] -- confirm
    # against DDPM. The permute moves axis 1 to the end so that each sample
    # can be handed to PIL, and the affine map rescales [-1, 1] to [0, 255].
    image_processed = image.cpu().permute(0, 2, 3, 1)
    image_processed = (image_processed + 1.0) * 127.5

    # Clip before casting: a float outside [0, 255] does not saturate when
    # converted to uint8 (it wraps or is undefined), which would show up as
    # speckled pixels if the model ever overshoots its nominal range.
    image_processed = np.clip(image_processed.numpy(), 0, 255).astype(np.uint8)

    for i, sample in enumerate(image_processed):
        image_pil = PIL.Image.fromarray(sample)
        image_pil.save(os.path.join(path, f"image_{i}.png"))
|