From 46dae846dfd083a1c29c4c88e813a470c045c846 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Wed, 8 Jun 2022 13:09:49 +0000
Subject: [PATCH] add clip to ddim

---
Review notes (this commentary area is ignored by `git am`):

* Line structure was reconstructed from a whitespace-collapsed copy of this
  patch. Hunk sizes match their headers (example.py: 23 added lines;
  modeling_ddim.py: 6 context lines + 1 added line).
* The leading indentation of the modeling_ddim.py context lines is assumed
  to be 12 spaces -- TODO confirm against the target file, or apply with
  `git apply --ignore-whitespace`.
* NOTE(review): example.py writes its images under the hard-coded absolute
  directory /home/patrick/images/hf.
* NOTE(review): example.py casts `(image + 1.0) * 127.5` straight to
  np.uint8 with no clamp to [0, 255]; this is only safe if the pipeline
  output is guaranteed to lie in [-1, 1] -- confirm.

 models/vision/ddim/example.py       | 23 +++++++++++++++++++++++
 models/vision/ddim/modeling_ddim.py |  1 +
 2 files changed, 24 insertions(+)
 create mode 100755 models/vision/ddim/example.py

diff --git a/models/vision/ddim/example.py b/models/vision/ddim/example.py
new file mode 100755
index 00000000..52f75b62
--- /dev/null
+++ b/models/vision/ddim/example.py
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+import os
+import pathlib
+from modeling_ddim import DDIM
+import PIL.Image
+import numpy as np
+
+model_ids = ["ddim-celeba-hq", "ddim-lsun-church", "ddim-lsun-bedroom"]
+
+for model_id in model_ids:
+    path = os.path.join("/home/patrick/images/hf", model_id)
+    pathlib.Path(path).mkdir(parents=True, exist_ok=True)
+
+    ddpm = DDIM.from_pretrained("fusing/" + model_id)
+    image = ddpm(batch_size=4)
+
+    image_processed = image.cpu().permute(0, 2, 3, 1)
+    image_processed = (image_processed + 1.0) * 127.5
+    image_processed = image_processed.numpy().astype(np.uint8)
+
+    for i in range(image_processed.shape[0]):
+        image_pil = PIL.Image.fromarray(image_processed[i])
+        image_pil.save(os.path.join(path, f"image_{i}.png"))
diff --git a/models/vision/ddim/modeling_ddim.py b/models/vision/ddim/modeling_ddim.py
index fa3350f5..1e1ffea0 100644
--- a/models/vision/ddim/modeling_ddim.py
+++ b/models/vision/ddim/modeling_ddim.py
@@ -59,6 +59,7 @@ class DDIM(DiffusionPipeline):
 
             # predict mean of prev image
             pred_mean = alpha_prod_t_rsqrt * (image - beta_prod_t_sqrt * noise_residual)
+            pred_mean = torch.clamp(pred_mean, -1, 1)
             pred_mean = (1 / alpha_prod_t_prev_rsqrt) * pred_mean + coeff_2 * noise_residual
 
             # if eta > 0.0 add noise. Note eta = 1.0 essentially corresponds to DDPM