Training example parameterization
This commit is contained in:
parent
31a7c75be9
commit
cfe6eb1611
|
@ -1,10 +1,10 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import torch
|
||||
import PIL.Image
|
||||
import argparse
|
||||
import torch.nn.functional as F
|
||||
|
||||
import PIL.Image
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from diffusers import DDPM, DDPMScheduler, UNetModel
|
||||
|
@ -31,44 +31,40 @@ def main(args):
|
|||
dropout=0.0,
|
||||
num_res_blocks=2,
|
||||
resamp_with_conv=True,
|
||||
resolution=64,
|
||||
resolution=args.resolution,
|
||||
)
|
||||
noise_scheduler = DDPMScheduler(timesteps=1000)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
||||
|
||||
num_epochs = 100
|
||||
batch_size = 16
|
||||
gradient_accumulation_steps = 1
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
augmentations = Compose(
|
||||
[
|
||||
Resize(64, interpolation=InterpolationMode.BILINEAR),
|
||||
RandomCrop(64),
|
||||
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
|
||||
RandomCrop(args.resolution),
|
||||
RandomHorizontalFlip(),
|
||||
ToTensor(),
|
||||
Lambda(lambda x: x * 2 - 1),
|
||||
]
|
||||
)
|
||||
dataset = load_dataset("huggan/pokemon", split="train")
|
||||
dataset = load_dataset(args.dataset, split="train")
|
||||
|
||||
def transforms(examples):
|
||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||
return {"input": images}
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
|
||||
|
||||
lr_scheduler = get_linear_schedule_with_warmup(
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=500,
|
||||
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
|
||||
num_warmup_steps=args.warmup_steps,
|
||||
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for epoch in range(args.num_epochs):
|
||||
model.train()
|
||||
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
|
||||
pbar.set_description(f"Epoch {epoch}")
|
||||
|
@ -84,14 +80,15 @@ def main(args):
|
|||
noise_samples[idx] = noise
|
||||
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
|
||||
|
||||
if step % gradient_accumulation_steps != 0:
|
||||
if step % args.gradient_accumulation_steps != 0:
|
||||
with accelerator.no_sync(model):
|
||||
output = model(noisy_images, timesteps)
|
||||
# predict the noise
|
||||
# predict the noise residual
|
||||
loss = F.mse_loss(output, noise_samples)
|
||||
accelerator.backward(loss)
|
||||
else:
|
||||
output = model(noisy_images, timesteps)
|
||||
# predict the noise residual
|
||||
loss = F.mse_loss(output, noise_samples)
|
||||
accelerator.backward(loss)
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
|
@ -103,13 +100,18 @@ def main(args):
|
|||
|
||||
optimizer.step()
|
||||
|
||||
# Generate a sample image for visual inspection
|
||||
torch.distributed.barrier()
|
||||
if args.local_rank in [-1, 0]:
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
|
||||
generator = torch.Generator()
|
||||
generator = generator.manual_seed(0)
|
||||
else:
|
||||
pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
|
||||
pipeline.save_pretrained(args.output_path)
|
||||
|
||||
generator = torch.manual_seed(0)
|
||||
# run pipeline in inference (sample random noise and denoise)
|
||||
image = pipeline(generator=generator)
|
||||
|
||||
|
@ -120,14 +122,23 @@ def main(args):
|
|||
image_pil = PIL.Image.fromarray(image_processed[0])
|
||||
|
||||
# save image
|
||||
pipeline.save_pretrained("./pokemon-ddpm")
|
||||
image_pil.save(f"./pokemon-ddpm/test_{epoch}.png")
|
||||
test_dir = os.path.join(args.output_path, "test_samples")
|
||||
os.makedirs(test_dir, exist_ok=True)
|
||||
image_pil.save(f"{test_dir}/{epoch}.png")
|
||||
torch.distributed.barrier()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument("--local_rank", type=int)
|
||||
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
|
||||
parser.add_argument("--resolution", type=int, default=64)
|
||||
parser.add_argument("--output_path", type=str, default="ddpm-model")
|
||||
parser.add_argument("--batch_size", type=int, default=16)
|
||||
parser.add_argument("--num_epochs", type=int, default=100)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
|
||||
parser.add_argument("--lr", type=float, default=1e-4)
|
||||
parser.add_argument("--warmup_steps", type=int, default=500)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
|
@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
@slow
|
||||
def test_glide_text2img(self):
|
||||
model_id = "fusing/glide-base"
|
||||
glide = GLIDE.from_pretrained(model_id)
|
||||
|
||||
prompt = "a pencil sketch of a corgi"
|
||||
generator = torch.manual_seed(0)
|
||||
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
|
||||
|
||||
image_slice = image[0, :3, :3, -1].cpu()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
||||
def test_module_from_pipeline(self):
|
||||
model = DiffWave(num_res_layers=4)
|
||||
noise_scheduler = DDPMScheduler(timesteps=12)
|
||||
|
@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase):
|
|||
_ = BDDM.from_pretrained(tmpdirname)
|
||||
# check if the same works using the DifusionPipeline class
|
||||
_ = DiffusionPipeline.from_pretrained(tmpdirname)
|
||||
@slow
|
||||
def test_glide_text2img(self):
|
||||
model_id = "fusing/glide-base"
|
||||
glide = GLIDE.from_pretrained(model_id)
|
||||
|
||||
prompt = "a pencil sketch of a corgi"
|
||||
generator = torch.manual_seed(0)
|
||||
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
|
||||
|
||||
image_slice = image[0, :3, :3, -1].cpu()
|
||||
|
||||
assert image.shape == (1, 256, 256, 3)
|
||||
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
|
||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||
|
|
Loading…
Reference in New Issue