Training example parameterization
This commit is contained in:
parent
31a7c75be9
commit
cfe6eb1611
|
@ -1,10 +1,10 @@
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import PIL.Image
|
|
||||||
import argparse
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
from diffusers import DDPM, DDPMScheduler, UNetModel
|
from diffusers import DDPM, DDPMScheduler, UNetModel
|
||||||
|
@ -31,44 +31,40 @@ def main(args):
|
||||||
dropout=0.0,
|
dropout=0.0,
|
||||||
num_res_blocks=2,
|
num_res_blocks=2,
|
||||||
resamp_with_conv=True,
|
resamp_with_conv=True,
|
||||||
resolution=64,
|
resolution=args.resolution,
|
||||||
)
|
)
|
||||||
noise_scheduler = DDPMScheduler(timesteps=1000)
|
noise_scheduler = DDPMScheduler(timesteps=1000)
|
||||||
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
|
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||||
|
|
||||||
num_epochs = 100
|
|
||||||
batch_size = 16
|
|
||||||
gradient_accumulation_steps = 1
|
|
||||||
|
|
||||||
augmentations = Compose(
|
augmentations = Compose(
|
||||||
[
|
[
|
||||||
Resize(64, interpolation=InterpolationMode.BILINEAR),
|
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
|
||||||
RandomCrop(64),
|
RandomCrop(args.resolution),
|
||||||
RandomHorizontalFlip(),
|
RandomHorizontalFlip(),
|
||||||
ToTensor(),
|
ToTensor(),
|
||||||
Lambda(lambda x: x * 2 - 1),
|
Lambda(lambda x: x * 2 - 1),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
dataset = load_dataset("huggan/pokemon", split="train")
|
dataset = load_dataset(args.dataset, split="train")
|
||||||
|
|
||||||
def transforms(examples):
|
def transforms(examples):
|
||||||
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
|
||||||
return {"input": images}
|
return {"input": images}
|
||||||
|
|
||||||
dataset.set_transform(transforms)
|
dataset.set_transform(transforms)
|
||||||
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
|
||||||
|
|
||||||
lr_scheduler = get_linear_schedule_with_warmup(
|
lr_scheduler = get_linear_schedule_with_warmup(
|
||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
num_warmup_steps=500,
|
num_warmup_steps=args.warmup_steps,
|
||||||
num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
|
num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||||
model, optimizer, train_dataloader, lr_scheduler
|
model, optimizer, train_dataloader, lr_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
for epoch in range(num_epochs):
|
for epoch in range(args.num_epochs):
|
||||||
model.train()
|
model.train()
|
||||||
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
|
with tqdm(total=len(train_dataloader), unit="ba") as pbar:
|
||||||
pbar.set_description(f"Epoch {epoch}")
|
pbar.set_description(f"Epoch {epoch}")
|
||||||
|
@ -84,14 +80,15 @@ def main(args):
|
||||||
noise_samples[idx] = noise
|
noise_samples[idx] = noise
|
||||||
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
|
noisy_images[idx] = noise_scheduler.forward_step(clean_images[idx], noise, timesteps[idx])
|
||||||
|
|
||||||
if step % gradient_accumulation_steps != 0:
|
if step % args.gradient_accumulation_steps != 0:
|
||||||
with accelerator.no_sync(model):
|
with accelerator.no_sync(model):
|
||||||
output = model(noisy_images, timesteps)
|
output = model(noisy_images, timesteps)
|
||||||
# predict the noise
|
# predict the noise residual
|
||||||
loss = F.mse_loss(output, noise_samples)
|
loss = F.mse_loss(output, noise_samples)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
else:
|
else:
|
||||||
output = model(noisy_images, timesteps)
|
output = model(noisy_images, timesteps)
|
||||||
|
# predict the noise residual
|
||||||
loss = F.mse_loss(output, noise_samples)
|
loss = F.mse_loss(output, noise_samples)
|
||||||
accelerator.backward(loss)
|
accelerator.backward(loss)
|
||||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||||
|
@ -103,13 +100,18 @@ def main(args):
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
# Generate a sample image for visual inspection
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
model.eval()
|
model.eval()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
|
if isinstance(model, torch.nn.parallel.DistributedDataParallel):
|
||||||
generator = torch.Generator()
|
pipeline = DDPM(unet=model.module, noise_scheduler=noise_scheduler)
|
||||||
generator = generator.manual_seed(0)
|
else:
|
||||||
|
pipeline = DDPM(unet=model, noise_scheduler=noise_scheduler)
|
||||||
|
pipeline.save_pretrained(args.output_path)
|
||||||
|
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
# run pipeline in inference (sample random noise and denoise)
|
# run pipeline in inference (sample random noise and denoise)
|
||||||
image = pipeline(generator=generator)
|
image = pipeline(generator=generator)
|
||||||
|
|
||||||
|
@ -120,22 +122,31 @@ def main(args):
|
||||||
image_pil = PIL.Image.fromarray(image_processed[0])
|
image_pil = PIL.Image.fromarray(image_processed[0])
|
||||||
|
|
||||||
# save image
|
# save image
|
||||||
pipeline.save_pretrained("./pokemon-ddpm")
|
test_dir = os.path.join(args.output_path, "test_samples")
|
||||||
image_pil.save(f"./pokemon-ddpm/test_{epoch}.png")
|
os.makedirs(test_dir, exist_ok=True)
|
||||||
|
image_pil.save(f"{test_dir}/{epoch}.png")
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Simple example of training script.")
|
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||||
parser.add_argument("--local_rank", type=int)
|
parser.add_argument("--local_rank", type=int)
|
||||||
|
parser.add_argument("--dataset", type=str, default="huggan/flowers-102-categories")
|
||||||
|
parser.add_argument("--resolution", type=int, default=64)
|
||||||
|
parser.add_argument("--output_path", type=str, default="ddpm-model")
|
||||||
|
parser.add_argument("--batch_size", type=int, default=16)
|
||||||
|
parser.add_argument("--num_epochs", type=int, default=100)
|
||||||
|
parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
|
||||||
|
parser.add_argument("--lr", type=float, default=1e-4)
|
||||||
|
parser.add_argument("--warmup_steps", type=int, default=500)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--mixed_precision",
|
"--mixed_precision",
|
||||||
type=str,
|
type=str,
|
||||||
default="no",
|
default="no",
|
||||||
choices=["no", "fp16", "bf16"],
|
choices=["no", "fp16", "bf16"],
|
||||||
help="Whether to use mixed precision. Choose"
|
help="Whether to use mixed precision. Choose"
|
||||||
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
"between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
|
||||||
"and an Nvidia Ampere GPU.",
|
"and an Nvidia Ampere GPU.",
|
||||||
)
|
)
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
|
@ -214,6 +214,21 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
|
||||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_glide_text2img(self):
|
||||||
|
model_id = "fusing/glide-base"
|
||||||
|
glide = GLIDE.from_pretrained(model_id)
|
||||||
|
|
||||||
|
prompt = "a pencil sketch of a corgi"
|
||||||
|
generator = torch.manual_seed(0)
|
||||||
|
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
|
||||||
|
|
||||||
|
image_slice = image[0, :3, :3, -1].cpu()
|
||||||
|
|
||||||
|
assert image.shape == (1, 256, 256, 3)
|
||||||
|
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
|
||||||
|
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
||||||
|
|
||||||
def test_module_from_pipeline(self):
|
def test_module_from_pipeline(self):
|
||||||
model = DiffWave(num_res_layers=4)
|
model = DiffWave(num_res_layers=4)
|
||||||
noise_scheduler = DDPMScheduler(timesteps=12)
|
noise_scheduler = DDPMScheduler(timesteps=12)
|
||||||
|
@ -229,17 +244,3 @@ class PipelineTesterMixin(unittest.TestCase):
|
||||||
_ = BDDM.from_pretrained(tmpdirname)
|
_ = BDDM.from_pretrained(tmpdirname)
|
||||||
# check if the same works using the DifusionPipeline class
|
# check if the same works using the DifusionPipeline class
|
||||||
_ = DiffusionPipeline.from_pretrained(tmpdirname)
|
_ = DiffusionPipeline.from_pretrained(tmpdirname)
|
||||||
@slow
|
|
||||||
def test_glide_text2img(self):
|
|
||||||
model_id = "fusing/glide-base"
|
|
||||||
glide = GLIDE.from_pretrained(model_id)
|
|
||||||
|
|
||||||
prompt = "a pencil sketch of a corgi"
|
|
||||||
generator = torch.manual_seed(0)
|
|
||||||
image = glide(prompt, generator=generator, num_inference_steps_upscale=20)
|
|
||||||
|
|
||||||
image_slice = image[0, :3, :3, -1].cpu()
|
|
||||||
|
|
||||||
assert image.shape == (1, 256, 256, 3)
|
|
||||||
expected_slice = torch.tensor([0.7119, 0.7073, 0.6460, 0.7780, 0.7423, 0.6926, 0.7378, 0.7189, 0.7784])
|
|
||||||
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
|
|
||||||
|
|
Loading…
Reference in New Issue