Training example parameterization

This commit is contained in:
anton-l 2022-06-15 11:21:02 +02:00
parent 31a7c75be9
commit cfe6eb1611
2 changed files with 51 additions and 39 deletions

View File

@ -1,10 +1,10 @@
import argparse
import os
import torch
import PIL.Image
import argparse
import torch.nn.functional as F
import PIL.Image
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPM, DDPMScheduler, UNetModel
@ -31,44 +31,40 @@ def main(args):
dropout=0.0,
num_res_blocks=2,
resamp_with_conv=True,
resolution=64,
resolution=args.resolution,
)
noise_scheduler = DDPMScheduler(timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 100
batch_size = 16
gradient_accumulation_steps = 1
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
augmentations = Compose(
[
Resize(64, interpolation=InterpolationMode.BILINEAR),
RandomCrop(64),
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
RandomCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Lambda(lambda x: x * 2 - 1),
]
)
dataset = load_dataset("huggan/pokemon", split="train")
dataset = load_dataset(args.dataset, split="train")
def transforms(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
dataset.set_transform(transforms)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
lr_scheduler = get_linear_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=500,
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
num_warmup_steps=args.warmup_steps,
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
for epoch in range(num_epochs):
for epoch in range(args.num_epochs):
model.train()
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
pbar.set_description(f"Epoch {epoch}")
@ -84,14 +80,15 @@ def main(args):
noise_samples[idx] = noise
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
if step % gradient_accumulation_steps != 0:
if step % args.gradient_accumulation_steps != 0:
with accelerator.no_sync(model):
output = model(noisy_images, timesteps)
# predict the noise
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
accelerator.backward(loss)
else:
output = model(noisy_images, timesteps)
# predict the noise residual
loss = F.mse_loss(output, noise_samples)
accelerator.backward(loss)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@ -103,13 +100,18 @@ def main(args):
optimizer.step()
# Generate a sample image for visual inspection
torch.distributed.barrier()
if args.local_rank in [-1, 0]:
model.eval()
with torch.no_grad():
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
generator = torch.Generator()
generator = generator.manual_seed(0)
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
else:
pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
pipeline.save_pretrained(args.output_path)
generator = torch.manual_seed(0)
# run pipeline in inference (sample random noise and denoise)
image = pipeline(generator=generator)
@ -120,22 +122,31 @@ def main(args):
image_pil = PIL.Image.fromarray(image_processed[0])
# save image
pipeline.save_pretrained("./pokemon-ddpm")
image_pil.save(f"./pokemon-ddpm/test_{epoch}.png")
test_dir = os.path.join(args.output_path, "test_samples")
os.makedirs(test_dir, exist_ok=True)
image_pil.save(f"{test_dir}/{epoch}.png")
torch.distributed.barrier()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Simple example of training script.")
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument("--local_rank", type=int)
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
parser.add_argument("--resolution", type=int, default=64)
parser.add_argument("--output_path", type=str, default="ddpm-model")
parser.add_argument("--batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--warmup_steps", type=int, default=500)
parser.add_argument(
"--mixed_precision",
type=str,
default="no",
choices=["no", "fp16", "bf16"],
help="Whether to use mixed precision. Choose"
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
"and an Nvidia Ampere GPU.",
)
args = parser.parse_args()

View File

@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase):
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12)
@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase):
_ = BDDM.from_pretrained(tmpdirname)
# check if the same works using the DifusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname)
@slow
def test_glide_text2img(self):
model_id = "fusing/glide-base"
glide = GLIDE.from_pretrained(model_id)
prompt = "a pencil sketch of a corgi"
generator = torch.manual_seed(0)
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
image_slice = image[0, :3, :3, -1].cpu()
assert image.shape == (1, 256, 256, 3)
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2