parent
3b7f514a1c
commit
5311f564ed
|
@ -6,19 +6,20 @@ from tqdm.auto import tqdm
|
|||
|
||||
|
||||
class ScoreSdeVePipeline(DiffusionPipeline):
|
||||
def __init__(self, model, scheduler):
|
||||
def __init__(self, unet, scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(model=model, scheduler=scheduler)
|
||||
self.register_modules(unet=unet, scheduler=scheduler)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, batch_size=1, num_inference_steps=2000, generator=None, torch_device=None, output_type="pil"):
|
||||
|
||||
if torch_device is None:
|
||||
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
img_size = self.model.config.sample_size
|
||||
img_size = self.unet.config.sample_size
|
||||
shape = (batch_size, 3, img_size, img_size)
|
||||
|
||||
model = self.model.to(torch_device)
|
||||
model = self.unet.to(torch_device)
|
||||
|
||||
sample = torch.randn(*shape) * self.scheduler.config.sigma_max
|
||||
sample = sample.to(torch_device)
|
||||
|
@ -31,7 +32,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
|||
|
||||
# correction step
|
||||
for _ in range(self.scheduler.correct_steps):
|
||||
model_output = self.model(sample, sigma_t)["sample"]
|
||||
model_output = self.unet(sample, sigma_t)["sample"]
|
||||
sample = self.scheduler.step_correct(model_output, sample)["prev_sample"]
|
||||
|
||||
# prediction step
|
||||
|
@ -40,7 +41,7 @@ class ScoreSdeVePipeline(DiffusionPipeline):
|
|||
|
||||
sample, sample_mean = output["prev_sample"], output["prev_sample_mean"]
|
||||
|
||||
sample = sample.clamp(0, 1)
|
||||
sample = sample_mean.clamp(0, 1)
|
||||
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
|
||||
if output_type == "pil":
|
||||
sample = self.numpy_to_pil(sample)
|
||||
|
|
|
@ -848,15 +848,12 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
|
||||
@slow
|
||||
def test_score_sde_ve_pipeline(self):
|
||||
model = UNet2DModel.from_pretrained("google/ncsnpp-church-256")
|
||||
model_id = "google/ncsnpp-church-256"
|
||||
model = UNet2DModel.from_pretrained(model_id)
|
||||
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
scheduler = ScoreSdeVeScheduler.from_config(model_id)
|
||||
|
||||
scheduler = ScoreSdeVeScheduler.from_config("google/ncsnpp-church-256")
|
||||
|
||||
sde_ve = ScoreSdeVePipeline(model=model, scheduler=scheduler)
|
||||
sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler)
|
||||
|
||||
torch.manual_seed(0)
|
||||
image = sde_ve(num_inference_steps=300, output_type="numpy")["sample"]
|
||||
|
|
Loading…
Reference in New Issue