This commit is contained in:
Patrick von Platen 2022-06-26 11:02:57 +00:00
parent 135acd83af
commit d5c527a499
1 changed files with 2 additions and 53 deletions

View File

@ -1,14 +1,9 @@
#!/usr/bin/env python3
import numpy as np
import torch
import PIL
from diffusers import DiffusionPipeline
# TODO(Patrick, Anton, Suraj) - rename `x` to better variable names
class ScoreSdeVePipeline(DiffusionPipeline):
def __init__(self, model, scheduler):
super().__init__()
@ -23,7 +18,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
model = self.model.to(device)
centered = False
# TODO(Patrick) move to scheduler config
n_steps = 1
x = torch.randn(*shape) * self.scheduler.config.sigma_max
@ -45,50 +40,4 @@ class ScoreSdeVePipeline(DiffusionPipeline):
x, x_mean = self.scheduler.step_pred(result, x, t)
x = x_mean
if centered:
x = (x + 1.0) / 2.0
return x
# from configs.ve import ffhq_ncsnpp_continuous as configs
# from configs.ve import cifar10_ncsnpp_continuous as configs
# ckpt_filename = "exp/ve/cifar10_ncsnpp_continuous/checkpoint_24.pth"
# ckpt_filename = "exp/ve/ffhq_1024_ncsnpp_continuous/checkpoint_60.pth"
# Note usually we need to restore ema etc...
# ema restored checkpoint used from below
# pipeline = ScoreSdeVePipeline.from_pretrained("/home/patrick/ffhq_ncsnpp")
# x = pipeline(num_inference_steps=2)
# for 5 cifar10
# x_sum = 106071.9922
# x_mean = 34.52864456176758
# for 1000 cifar10
# x_sum = 461.9700
# x_mean = 0.1504
# for N=2 for 1024
# x_sum = 3382810112.0
# x_mean = 1075.366455078125
#
#
# def check_x_sum_x_mean(x, x_sum, x_mean):
# assert (x.abs().sum() - x_sum).abs().cpu().item() < 1e-2, f"sum wrong {x.abs().sum()}"
# assert (x.abs().mean() - x_mean).abs().cpu().item() < 1e-4, f"mean wrong {x.abs().mean()}"
#
#
# check_x_sum_x_mean(x, x_sum, x_mean)
#
#
# def save_image(x):
# image_processed = np.clip(x.permute(0, 2, 3, 1).cpu().numpy() * 255, 0, 255).astype(np.uint8)
# image_pil = PIL.Image.fromarray(image_processed[0])
# image_pil.save("../images/hey.png")
#
#
# save_image(x)
return x_mean