Training example parameterization

This commit is contained in:
anton-l 2022-06-15 11:21:02 +02:00
parent 31a7c75be9
commit cfe6eb1611
2 changed files with 51 additions and 39 deletions

View File

@ -1,10 +1,10 @@
import argparse
import os import os
import torch import torch
import PIL.Image
import argparse
import torch.nn.functional as F import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator from accelerate import Accelerator
from datasets import load_dataset from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel from diffusers import DDPM, DDPMScheduler, UNetModel
@ -31,44 +31,40 @@ def main(args):
dropout=0.0, dropout=0.0,
num_res_blocks=2, num_res_blocks=2,
resamp_with_conv=True, resamp_with_conv=True,
resolution=64, resolution=args.resolution,
) )
noise_scheduler = DDPMScheduler(timesteps=1000) noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
num_epochs = 100
batch_size = 16
gradient_accumulation_steps = 1
augmentations = Compose( augmentations = Compose(
[ [
Resize(64, interpolation=InterpolationMode.BILINEAR), Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
RandomCrop(64), RandomCrop(args.resolution),
RandomHorizontalFlip(), RandomHorizontalFlip(),
ToTensor(), ToTensor(),
Lambda(lambda x: x * 2 - 1), Lambda(lambda x: x * 2 - 1),
] ]
) )
dataset = load_dataset("huggan/pokemon", split="train") dataset = load_dataset(args.dataset, split="train")
def transforms(examples): def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]] images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images} return {"input": images}
dataset.set_transform(transforms) dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
lr_scheduler = get_linear_schedule_with_warmup( lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer, optimizer=optimizer,
num_warmup_steps=500, num_warmup_steps=args.warmup_steps,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps, num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
) )
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler model, optimizer, train_dataloader, lr_scheduler
) )
for epoch in range(num_epochs): for epoch in range(args.num_epochs):
model.train() model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar: with tqdm(total=len(train_dataloader), unit="ba") as pbar:
pbar.set_description(f"Epoch {epoch}") pbar.set_description(f"Epoch {epoch}")
@ -84,14 +80,15 @@ def main(args):
noise_samples[idx] = noise noise_samples[idx] = noise
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx]) noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
if step % gradient_accumulation_steps != 0: if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model): with accelerator.no_sync(model):
output = model(noisy_images, timesteps) output = model(noisy_images, timesteps)
# predict the noise # predict the noise residual
loss = F.mse_loss(output, noise_samples) loss = F.mse_loss(output, noise_samples)
accelerator.backward(loss) accelerator.backward(loss)
else: else:
output = model(noisy_images, timesteps) output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples) loss = F.mse_loss(output, noise_samples)
accelerator.backward(loss) accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@ -103,13 +100,18 @@ def main(args):
optimizer.step() optimizer.step()
# Generate a sample image for visual inspection
torch.distributed.barrier() torch.distributed.barrier()
if args.local_rank in [-1, 0]: if args.local_rank in [-1, 0]:
model.eval() model.eval()
with torch.no_grad(): with torch.no_grad():
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler) pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
generator = torch.Generator() else:
generator = generator.manual_seed(0) pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
pipeline.save_pretrained(args.output_path)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise) # run pipeline in inference (sample random noise and denoise)
image = pipeline(generator=generator) image = pipeline(generator=generator)
@ -120,14 +122,23 @@ def main(args):
image_pil = PIL.Image.fromarray(image_processed[0]) image_pil = PIL.Image.fromarray(image_processed[0])
# save image # save image
pipeline.save_pretrained("./pokemon-ddpm") test_dir = os.path.join(args.output_path, "test_samples")
image_pil.save(f"./pokemon-ddpm/test_{epoch}.png") os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch}.png")
torch.distributed.barrier() torch.distributed.barrier()
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of training script.") parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int) parser.add_argument("--local_rank", type=int)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--output_path", type=str, default="ddpm-model")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument( parser.add_argument(
"--mixed_precision", "--mixed_precision",
type=str, type=str,

View File

@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458]) expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2 assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self): def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4) model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12) noise_scheduler = DDPMScheduler(timesteps=12)
@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase):
_ = BDDM.from_pretrained(tmpdirname) _ = BDDM.from_pretrained(tmpdirname)
# check if the same works using the DifusionPipeline class # check if the same works using the DifusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname) _ = DiffusionPipeline.from_pretrained(tmpdirname)
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2