parent
3b7f514a1c
commit
5311f564ed
|
@ -6,19 +6,20 @@ from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
|
||||||
class ScoreSdeVePipeline(DiffusionPipeline):
|
class ScoreSdeVePipeline(DiffusionPipeline):
|
||||||
def __init__(self, model, scheduler):
|
def __init__(self, unet, scheduler):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.register_modules(model=model, scheduler=scheduler)
|
self.register_modules(unet=unet, scheduler=scheduler)
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
|
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
|
||||||
|
|
||||||
if torch_device is None:
|
if torch_device is None:
|
||||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
img_size = self.model.config.sample_size
|
img_size = self.unet.config.sample_size
|
||||||
shape = (batch_size, 3, img_size, img_size)
|
shape = (batch_size, 3, img_size, img_size)
|
||||||
|
|
||||||
model = self.model.to(torch_device)
|
model = self.unet.to(torch_device)
|
||||||
|
|
||||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||||
sample = sample.to(torch_device)
|
sample = sample.to(torch_device)
|
||||||
|
@ -31,7 +32,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||||
|
|
||||||
# correction step
|
# correction step
|
||||||
for _ in range(self.scheduler.correct_steps):
|
for _ in range(self.scheduler.correct_steps):
|
||||||
model_output = self.model(sample, sigma_t)["sample"]
|
model_output = self.unet(sample, sigma_t)["sample"]
|
||||||
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
|
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
|
||||||
|
|
||||||
# prediction step
|
# prediction step
|
||||||
|
@ -40,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
||||||
|
|
||||||
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
|
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
|
||||||
|
|
||||||
sample = sample.clamp(0, 1)
|
sample = sample_mean.clamp(0, 1)
|
||||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||||
if output_type == "pil":
|
if output_type == "pil":
|
||||||
sample = self.numpy_to_pil(sample)
|
sample = self.numpy_to_pil(sample)
|
||||||
|
|
|
@ -848,15 +848,12 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_score_sde_ve_pipeline(self):
|
def test_score_sde_ve_pipeline(self):
|
||||||
model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
|
model_id = "google/ncsnpp-church-256"
|
||||||
|
model = UNet2DModel.from_pretrained(model_id)
|
||||||
|
|
||||||
torch.manual_seed(0)
|
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed_all(0)
|
|
||||||
|
|
||||||
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
|
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
||||||
|
|
||||||
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
|
|
||||||
|
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
|
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
|
||||||
|
|
Loading…
Reference in New Issue