add fp16 training

This commit is contained in:
harubaru 2022-10-15 10:57:00 -07:00
parent 9581fbc226
commit f980137430
7 changed files with 249 additions and 55 deletions

View File

@ -47,6 +47,7 @@ model:
params:
embed_dim: 4
monitor: val/rec_loss
ckpt_path: "../latent-diffusion/logs/original/checkpoints/last.ckpt"
ddconfig:
double_z: true
z_channels: 4
@ -69,22 +70,25 @@ model:
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
params:
penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
extended_mode: 3 # extend clip context to 225 tokens - as per NAI blogpost
data:
target: main.DataModuleFromConfig
params:
batch_size: 4
num_workers: 4
batch_size: 2
num_workers: 2
wrap: false
train:
target: ldm.data.local.LocalDanbooruBase
target: ldm.data.localdanboorubase.LocalDanbooruBase
params:
data_root: '../dataset'
size: 512
mode: "train"
ucg: 0.1 # unconditional guidance training
validation:
target: ldm.data.local.LocalDanbooruBase
target: ldm.data.localdanboorubase.LocalDanbooruBase
params:
data_root: '../dataset'
size: 512
mode: "val"
val_split: 64
@ -109,9 +113,11 @@ lightning:
plot_diffusion_rows: False
N: 4
ddim_steps: 50
trainer:
benchmark: True
val_check_interval: 5000000
num_sanity_val_steps: 0
accumulate_grad_batches: 1
trainer:
precision: 16
amp_backend: "native"
strategy: "fsdp"
benchmark: True
limit_val_batches: 0
num_sanity_val_steps: 0
accumulate_grad_batches: 1

View File

@ -0,0 +1,182 @@
import os
import numpy as np
import PIL
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF
from functools import partial
import copy
import glob
import random
PIL.Image.MAX_IMAGE_PIXELS = 933120000
import torchvision
import pytorch_lightning as pl
import torch
import re
import json
import io
def resize_image(image: Image, max_size=(768,768)):
image = ImageOps.contain(image, max_size, Image.LANCZOS)
# resize to integer multiple of 64
w, h = image.size
w, h = map(lambda x: x - x % 64, (w, h))
ratio = w / h
src_ratio = image.width / image.height
src_w = w if ratio > src_ratio else image.width * h // image.height
src_h = h if ratio <= src_ratio else image.height * w // image.width
resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
res = Image.new("RGB", (w, h))
res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
return res
class CaptionProcessor(object):
def __init__(self, transforms, max_size, resize, random_order, LR_size):
self.transforms = transforms
self.max_size = max_size
self.resize = resize
self.random_order = random_order
self.degradation_process = partial(TF.resize, size=LR_size, interpolation=TF.InterpolationMode.NEAREST)
def __call__(self, sample):
# preprocess caption
pass
# preprocess image
image = sample['image']
image = Image.open(io.BytesIO(image))
if self.resize:
image = resize_image(image, max_size=(self.max_size, self.max_size))
image = self.transforms(image)
lr_image = copy.deepcopy(image)
image = np.array(image).astype(np.uint8)
sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
# preprocess LR image
lr_image = self.degradation_process(lr_image)
lr_image = np.array(lr_image).astype(np.uint8)
sample['LR_image'] = (lr_image/127.5 - 1.0).astype(np.float32)
return sample
class LocalDanbooruBaseVAE(Dataset):
def __init__(self,
data_root='./danbooru-aesthetic',
size=256,
interpolation="bicubic",
flip_p=0.5,
crop=True,
shuffle=False,
mode='train',
val_split=64,
downscale_f=8
):
super().__init__()
self.shuffle=shuffle
self.crop = crop
print('Fetching data.')
ext = ['image']
self.image_files = []
[self.image_files.extend(glob.glob(f'{data_root}' + '/*.' + e)) for e in ext]
if mode == 'val':
self.image_files = self.image_files[:len(self.image_files)//val_split]
print(f'Constructing image map. Found {len(self.image_files)} images')
self.examples = {}
self.hashes = []
for i in self.image_files:
hash = i[len(f'{data_root}/'):].split('.')[0]
self.examples[hash] = {
'image': i
}
self.hashes.append(hash)
print(f'image map has {len(self.examples.keys())} examples')
self.size = size
self.interpolation = {"linear": PIL.Image.LINEAR,
"bilinear": PIL.Image.BILINEAR,
"bicubic": PIL.Image.BICUBIC,
"lanczos": PIL.Image.LANCZOS,
}[interpolation]
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
image_transforms = []
image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(flip_p)],)
image_transforms = torchvision.transforms.Compose(image_transforms)
self.captionprocessor = CaptionProcessor(image_transforms, self.size, True, True, int(size / downscale_f))
def random_sample(self):
return self.__getitem__(random.randint(0, self.__len__() - 1))
def sequential_sample(self, i):
if i >= self.__len__() - 1:
return self.__getitem__(0)
return self.__getitem__(i + 1)
def skip_sample(self, i):
return None
def __len__(self):
return len(self.image_files)
def __getitem__(self, i):
return self.get_image(i)
def get_image(self, i):
image = {}
try:
image_file = self.examples[self.hashes[i]]['image']
with open(image_file, 'rb') as f:
image['image'] = f.read()
image = self.captionprocessor(image)
except Exception as e:
print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
return self.skip_sample(i)
return image
"""
if __name__ == "__main__":
dataset = LocalBase('./danbooru-aesthetic', size=512, crop=False, mode='val')
print(dataset.__len__())
example = dataset.__getitem__(0)
print(dataset.hashes[0])
print(example['caption'])
image = example['image']
image = ((image + 1) * 127.5).astype(np.uint8)
image = Image.fromarray(image)
image.save('example.png')
"""
"""
from tqdm import tqdm
if __name__ == "__main__":
dataset = LocalDanbooruBase('./links', size=768)
import time
a = time.process_time()
for i in range(8):
example = dataset.get_image(i)
image = example['image']
image = ((image + 1) * 127.5).astype(np.uint8)
image = Image.fromarray(image)
image.save(f'example-{i}.png')
print(example['caption'])
print('time:', time.process_time()-a)
"""

View File

@ -340,17 +340,18 @@ class DDPM(pl.LightningModule):
return loss, loss_dict
def training_step(self, batch, batch_idx):
loss, loss_dict = self.shared_step(batch)
with torch.autocast('cuda'):
loss, loss_dict = self.shared_step(batch)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log_dict(loss_dict, prog_bar=True,
logger=True, on_step=True, on_epoch=True)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
self.log("global_step", self.global_step,
prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
if self.use_scheduler:
lr = self.optimizers().param_groups[0]['lr']
self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
return loss
@ -475,7 +476,7 @@ class LatentDiffusion(DDPM):
@rank_zero_only
@torch.no_grad()
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
def on_train_batch_start(self, batch, batch_idx):
# only for very first batch
if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'

View File

@ -119,33 +119,35 @@ def checkpoint(func, inputs, params, flag):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.autocast('cuda'):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
with torch.autocast('cuda'):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):

View File

@ -136,8 +136,7 @@ class SpatialRescaler(nn.Module):
return self(x)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from Hugging Face)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, extended_mode=None):
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
@ -145,7 +144,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
self.max_length = max_length
self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
self.extended_mode = extended_mode
self.max_chunks = max_chunks
self.freeze()
def freeze(self):
@ -168,7 +166,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
if self.extended_mode:
max_standard_tokens = self.max_length - 2
batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.extended_mode) - (self.extended_mode * 2), return_length=True, return_overflowing_tokens=False, padding=False,
add_special_tokens=False)
# get the max length aligned to chunk size.

17
main.py
View File

@ -295,7 +295,7 @@ class ImageLogger(Callback):
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = {
pl.loggers.TestTubeLogger: self._testtube,
pl.loggers.WandbLogger: self._testtube,
}
self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
@ -350,7 +350,8 @@ class ImageLogger(Callback):
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
with torch.autocast('cuda'):
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
@ -380,7 +381,7 @@ class ImageLogger(Callback):
return True
return False
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
@ -518,7 +519,7 @@ if __name__ == "__main__":
# merge trainer cli with config
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to ddp
trainer_config["accelerator"] = "ddp"
trainer_config["accelerator"] = "gpu"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if not "gpus" in trainer_config:
@ -556,7 +557,7 @@ if __name__ == "__main__":
}
},
}
default_logger_cfg = default_logger_cfgs["testtube"]
default_logger_cfg = default_logger_cfgs["wandb"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
@ -656,9 +657,11 @@ if __name__ == "__main__":
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer_kwargs["plugins"] = list()
from pytorch_lightning.plugins import DDPPlugin
trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin
#trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
trainer_kwargs["plugins"].append(NativeMixedPrecisionPlugin(16, 'cuda', torch.cuda.amp.GradScaler(enabled=True)))
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
#trainer = Trainer(gpus=1, precision=16, amp_backend="native", strategy="deepspeed_stage_2_offload", benchmark=True, limit_val_batches=0, num_sanity_val_steps=0, accumulate_grad_batches=1)
trainer.logdir = logdir ###
# data

View File

@ -4,7 +4,7 @@ opencv-python
pudb==2019.2
imageio==2.9.0
imageio-ffmpeg==0.4.2
pytorch-lightning==1.6.0
pytorch-lightning==1.7.7
omegaconf==2.1.1
test-tube>=0.7.5
streamlit>=0.73.1
@ -18,3 +18,5 @@ git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-tra
git+https://github.com/openai/CLIP.git@main#egg=clip
git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
webdataset
wandb
fairscale