add fp16 training

harubaru 2022-10-15 10:57:00 -07:00
parent 9581fbc226
commit f980137430
7 changed files with 249 additions and 55 deletions


@@ -47,6 +47,7 @@ model:
       params:
         embed_dim: 4
         monitor: val/rec_loss
+        ckpt_path: "../latent-diffusion/logs/original/checkpoints/last.ckpt"
         ddconfig:
           double_z: true
           z_channels: 4
@@ -69,22 +70,25 @@ model:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
       params:
         penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
+        extended_mode: 3 # extend clip context to 225 tokens - as per NAI blogpost

 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 4
-    num_workers: 4
+    batch_size: 2
+    num_workers: 2
     wrap: false
     train:
-      target: ldm.data.local.LocalDanbooruBase
+      target: ldm.data.localdanboorubase.LocalDanbooruBase
       params:
+        data_root: '../dataset'
         size: 512
         mode: "train"
         ucg: 0.1 # unconditional guidance training
     validation:
-      target: ldm.data.local.LocalDanbooruBase
+      target: ldm.data.localdanboorubase.LocalDanbooruBase
       params:
+        data_root: '../dataset'
         size: 512
         mode: "val"
         val_split: 64
@@ -109,9 +113,11 @@ lightning:
           plot_diffusion_rows: False
           N: 4
           ddim_steps: 50

   trainer:
+    precision: 16
+    amp_backend: "native"
+    strategy: "fsdp"
     benchmark: True
-    val_check_interval: 5000000
+    limit_val_batches: 0
     num_sanity_val_steps: 0
     accumulate_grad_batches: 1
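For orientation, the three new trainer keys map directly onto PyTorch Lightning 1.7 Trainer arguments. A minimal standalone sketch of the same setup (model and data are placeholders, not code from this repo):

import pytorch_lightning as pl

# Rough equivalent of the YAML trainer block above: fp16 autocast through the
# native torch.cuda.amp backend, FSDP sharding, and validation disabled.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    precision=16,               # run forward/backward under fp16 autocast with a GradScaler
    amp_backend="native",       # torch.cuda.amp rather than NVIDIA apex
    strategy="fsdp",            # fully sharded data parallel (hence the new fairscale requirement)
    benchmark=True,
    limit_val_batches=0,        # replaces the old val_check_interval: 5000000 workaround
    num_sanity_val_steps=0,
    accumulate_grad_batches=1,
)
# trainer.fit(model, datamodule=data)  # model/data come from instantiate_from_config in main.py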


@@ -0,0 +1,182 @@
import os
import numpy as np
import PIL
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF

from functools import partial
import copy
import glob
import random

PIL.Image.MAX_IMAGE_PIXELS = 933120000

import torchvision
import pytorch_lightning as pl
import torch
import re
import json
import io


def resize_image(image: Image, max_size=(768,768)):
    image = ImageOps.contain(image, max_size, Image.LANCZOS)
    # resize to integer multiple of 64
    w, h = image.size
    w, h = map(lambda x: x - x % 64, (w, h))

    ratio = w / h
    src_ratio = image.width / image.height

    src_w = w if ratio > src_ratio else image.width * h // image.height
    src_h = h if ratio <= src_ratio else image.height * w // image.width

    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
    res = Image.new("RGB", (w, h))
    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))

    return res


class CaptionProcessor(object):
    def __init__(self, transforms, max_size, resize, random_order, LR_size):
        self.transforms = transforms
        self.max_size = max_size
        self.resize = resize
        self.random_order = random_order
        self.degradation_process = partial(TF.resize, size=LR_size, interpolation=TF.InterpolationMode.NEAREST)

    def __call__(self, sample):
        # preprocess caption
        pass

        # preprocess image
        image = sample['image']
        image = Image.open(io.BytesIO(image))
        if self.resize:
            image = resize_image(image, max_size=(self.max_size, self.max_size))
        image = self.transforms(image)
        lr_image = copy.deepcopy(image)
        image = np.array(image).astype(np.uint8)
        sample['image'] = (image / 127.5 - 1.0).astype(np.float32)

        # preprocess LR image
        lr_image = self.degradation_process(lr_image)
        lr_image = np.array(lr_image).astype(np.uint8)
        sample['LR_image'] = (lr_image / 127.5 - 1.0).astype(np.float32)

        return sample


class LocalDanbooruBaseVAE(Dataset):
    def __init__(self,
                 data_root='./danbooru-aesthetic',
                 size=256,
                 interpolation="bicubic",
                 flip_p=0.5,
                 crop=True,
                 shuffle=False,
                 mode='train',
                 val_split=64,
                 downscale_f=8
                 ):
        super().__init__()
        self.shuffle = shuffle
        self.crop = crop

        print('Fetching data.')

        ext = ['image']
        self.image_files = []
        [self.image_files.extend(glob.glob(f'{data_root}' + '/*.' + e)) for e in ext]
        if mode == 'val':
            self.image_files = self.image_files[:len(self.image_files)//val_split]

        print(f'Constructing image map. Found {len(self.image_files)} images')

        self.examples = {}
        self.hashes = []
        for i in self.image_files:
            hash = i[len(f'{data_root}/'):].split('.')[0]
            self.examples[hash] = {
                'image': i
            }
            self.hashes.append(hash)

        print(f'image map has {len(self.examples.keys())} examples')

        self.size = size
        self.interpolation = {"linear": PIL.Image.LINEAR,
                              "bilinear": PIL.Image.BILINEAR,
                              "bicubic": PIL.Image.BICUBIC,
                              "lanczos": PIL.Image.LANCZOS,
                              }[interpolation]
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)

        image_transforms = []
        image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(flip_p)],)
        image_transforms = torchvision.transforms.Compose(image_transforms)

        self.captionprocessor = CaptionProcessor(image_transforms, self.size, True, True, int(size / downscale_f))

    def random_sample(self):
        return self.__getitem__(random.randint(0, self.__len__() - 1))

    def sequential_sample(self, i):
        if i >= self.__len__() - 1:
            return self.__getitem__(0)
        return self.__getitem__(i + 1)

    def skip_sample(self, i):
        return None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, i):
        return self.get_image(i)

    def get_image(self, i):
        image = {}
        try:
            image_file = self.examples[self.hashes[i]]['image']
            with open(image_file, 'rb') as f:
                image['image'] = f.read()
            image = self.captionprocessor(image)
        except Exception as e:
            print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
            return self.skip_sample(i)

        return image


"""
if __name__ == "__main__":
    dataset = LocalBase('./danbooru-aesthetic', size=512, crop=False, mode='val')
    print(dataset.__len__())
    example = dataset.__getitem__(0)
    print(dataset.hashes[0])
    print(example['caption'])
    image = example['image']
    image = ((image + 1) * 127.5).astype(np.uint8)
    image = Image.fromarray(image)
    image.save('example.png')
"""

"""
from tqdm import tqdm
if __name__ == "__main__":
    dataset = LocalDanbooruBase('./links', size=768)
    import time
    a = time.process_time()
    for i in range(8):
        example = dataset.get_image(i)
        image = example['image']
        image = ((image + 1) * 127.5).astype(np.uint8)
        image = Image.fromarray(image)
        image.save(f'example-{i}.png')
        print(example['caption'])
    print('time:', time.process_time()-a)
"""


@@ -340,6 +340,7 @@ class DDPM(pl.LightningModule):
         return loss, loss_dict

     def training_step(self, batch, batch_idx):
+        with torch.autocast('cuda'):
             loss, loss_dict = self.shared_step(batch)

             self.log_dict(loss_dict, prog_bar=True,
@@ -475,7 +476,7 @@ class LatentDiffusion(DDPM):
     @rank_zero_only
     @torch.no_grad()
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx):
         # only for very first batch
         if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
             assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
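The autocast context added to training_step is the same torch.cuda.amp machinery that the new precision: 16 / NativeMixedPrecisionPlugin settings rely on. Outside of Lightning the full pattern looks roughly like this (generic PyTorch sketch, not this repo's code; model, batch, and optimizer are placeholders):

import torch

scaler = torch.cuda.amp.GradScaler()

def fp16_step(model, batch, optimizer):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast('cuda'):           # mixed-precision forward, as added above
        loss = model(batch)
    scaler.scale(loss).backward()          # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)                 # unscales grads and skips the step on inf/nan
    scaler.update()
    return loss.detach()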


@@ -119,6 +119,7 @@ def checkpoint(func, inputs, params, flag):
 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, length, *args):
+        with torch.autocast('cuda'):
             ctx.run_function = run_function
             ctx.input_tensors = list(args[:length])
             ctx.input_params = list(args[length:])
@@ -129,6 +130,7 @@ class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, *output_grads):
+        with torch.autocast('cuda'):
             ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
             with torch.enable_grad():
                 # Fixes a bug where the first op in run_function modifies the
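This custom CheckpointFunction re-runs the forward function during backward, so autocast is entered in both methods to keep the recomputation in the same reduced-precision regime as the first pass. A quick illustration of what the context changes (generic PyTorch, not this file's code):

import torch

linear = torch.nn.Linear(16, 16).cuda()    # parameters stay fp32
x = torch.randn(4, 16, device='cuda')

with torch.autocast('cuda'):
    y = linear(x)
print(linear.weight.dtype, y.dtype)        # torch.float32 torch.float16

y2 = linear(x)                             # outside autocast: plain fp32 matmul
print(y2.dtype)                            # torch.float32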


@@ -136,8 +136,7 @@ class SpatialRescaler(nn.Module):
         return self(x)

 class FrozenCLIPEmbedder(AbstractEncoder):
-    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, extended_mode=None):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -145,7 +144,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         self.max_length = max_length
         self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
         self.extended_mode = extended_mode
-        self.max_chunks = max_chunks
         self.freeze()

     def freeze(self):
@@ -168,7 +166,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         if self.extended_mode:
             max_standard_tokens = self.max_length - 2
-            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.extended_mode) - (self.extended_mode * 2), return_length=True, return_overflowing_tokens=False, padding=False,
                                             add_special_tokens=False)

             # get the max length aligned to chunk size.
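For context on extended_mode: CLIP's max_length of 77 is 75 content tokens plus BOS/EOS, so extended_mode: 3 asks the tokenizer for up to 77*3 - 3*2 = 225 content tokens. The usual trick behind the "225 token" setting (as referenced in the NAI blogpost comment in the config) is to split those tokens into 75-token chunks, wrap each chunk with BOS/EOS, encode the chunks separately, and concatenate the embeddings. A rough sketch of the chunking arithmetic (hypothetical helper, not this file's implementation):

def chunk_token_ids(ids, bos, eos, pad, max_length=77, extended_mode=3):
    """Split already-tokenized ids (no special tokens) into BOS/EOS-wrapped chunks."""
    max_standard_tokens = max_length - 2                      # 75 content tokens per chunk
    budget = max_length * extended_mode - extended_mode * 2   # 225 for extended_mode=3
    ids = ids[:budget]
    chunks = []
    for start in range(0, len(ids), max_standard_tokens):
        chunk = ids[start:start + max_standard_tokens]
        chunk = chunk + [pad] * (max_standard_tokens - len(chunk))  # pad the last chunk
        chunks.append([bos] + chunk + [eos])                        # each chunk is length 77
    return chunks  # encode each chunk with CLIP and concatenate the hidden states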

main.py

@@ -295,7 +295,7 @@ class ImageLogger(Callback):
         self.batch_freq = batch_frequency
         self.max_images = max_images
         self.logger_log_images = {
-            pl.loggers.TestTubeLogger: self._testtube,
+            pl.loggers.WandbLogger: self._testtube,
         }
         self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
         if not increase_log_steps:
@@ -350,6 +350,7 @@ class ImageLogger(Callback):
                 pl_module.eval()

             with torch.no_grad():
+                with torch.autocast('cuda'):
                     images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

             for k in images:
@@ -380,7 +381,7 @@ class ImageLogger(Callback):
             return True
         return False

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
             self.log_img(pl_module, batch, batch_idx, split="train")
@@ -518,7 +519,7 @@ if __name__ == "__main__":
         # merge trainer cli with config
         trainer_config = lightning_config.get("trainer", OmegaConf.create())
         # default to ddp
-        trainer_config["accelerator"] = "ddp"
+        trainer_config["accelerator"] = "gpu"
         for k in nondefault_trainer_args(opt):
             trainer_config[k] = getattr(opt, k)
         if not "gpus" in trainer_config:
@@ -556,7 +557,7 @@ if __name__ == "__main__":
                 }
             },
         }
-        default_logger_cfg = default_logger_cfgs["testtube"]
+        default_logger_cfg = default_logger_cfgs["wandb"]
         if "logger" in lightning_config:
             logger_cfg = lightning_config.logger
         else:
@@ -656,9 +657,11 @@ if __name__ == "__main__":
         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
         trainer_kwargs["plugins"] = list()
-        from pytorch_lightning.plugins import DDPPlugin
-        trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin
+        #trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        trainer_kwargs["plugins"].append(NativeMixedPrecisionPlugin(16, 'cuda', torch.cuda.amp.GradScaler(enabled=True)))

         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+        #trainer = Trainer(gpus=1, precision=16, amp_backend="native", strategy="deepspeed_stage_2_offload", benchmark=True, limit_val_batches=0, num_sanity_val_steps=0, accumulate_grad_batches=1)
         trainer.logdir = logdir ###

         # data
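Taken together, the main.py changes swap the TestTube logger for wandb and hand Lightning an explicit native-AMP precision plugin. A simplified standalone sketch of the resulting Trainer wiring (the real script assembles these kwargs from the OmegaConf configs; the project name is a placeholder):

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

logger = WandbLogger(project="stable-diffusion-finetune")  # placeholder project name
precision_plugin = NativeMixedPrecisionPlugin(16, 'cuda', torch.cuda.amp.GradScaler(enabled=True))

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    logger=logger,
    plugins=[precision_plugin],
    benchmark=True,
    limit_val_batches=0,
    num_sanity_val_steps=0,
)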


@@ -4,7 +4,7 @@ opencv-python
 pudb==2019.2
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
-pytorch-lightning==1.6.0
+pytorch-lightning==1.7.7
 omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
@@ -18,3 +18,5 @@ git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-tra
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
 webdataset
+wandb
+fairscale