add fp16 training

parent 9581fbc226
commit f980137430
@@ -47,6 +47,7 @@ model:
       params:
         embed_dim: 4
         monitor: val/rec_loss
+        ckpt_path: "../latent-diffusion/logs/original/checkpoints/last.ckpt"
         ddconfig:
           double_z: true
           z_channels: 4
@@ -69,22 +70,25 @@ model:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
       params:
         penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
+        extended_mode: 3 # extend clip context to 225 tokens - as per NAI blogpost

 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 4
-    num_workers: 4
+    batch_size: 2
+    num_workers: 2
     wrap: false
     train:
-      target: ldm.data.local.LocalDanbooruBase
+      target: ldm.data.localdanboorubase.LocalDanbooruBase
       params:
+        data_root: '../dataset'
         size: 512
         mode: "train"
         ucg: 0.1 # unconditional guidance training
     validation:
-      target: ldm.data.local.LocalDanbooruBase
+      target: ldm.data.localdanboorubase.LocalDanbooruBase
       params:
+        data_root: '../dataset'
         size: 512
         mode: "val"
         val_split: 64
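The train and validation entries follow the target/params convention used throughout these configs: the dotted target path is imported and the params mapping is passed as keyword arguments, so the new ldm.data.localdanboorubase module is expected to expose a LocalDanbooruBase class accepting data_root, size, mode, ucg and val_split. A hedged sketch of that resolution step (a generic stand-in, not the repo's instantiate_from_config):

import importlib

def instantiate(target: str, params: dict):
    # "ldm.data.localdanboorubase.LocalDanbooruBase" -> module path + attribute name
    module_path, cls_name = target.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**params)

train_ds = instantiate("ldm.data.localdanboorubase.LocalDanbooruBase",
                       {"data_root": "../dataset", "size": 512, "mode": "train", "ucg": 0.1})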
@@ -109,9 +113,11 @@ lightning:
          plot_diffusion_rows: False
          N: 4
          ddim_steps: 50

   trainer:
+    precision: 16
+    amp_backend: "native"
+    strategy: "fsdp"
     benchmark: True
-    val_check_interval: 5000000
+    limit_val_batches: 0
     num_sanity_val_steps: 0
     accumulate_grad_batches: 1
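The three added keys are what switch training to fp16 from the config side: precision: 16 enables mixed precision, amp_backend: "native" selects torch.cuda.amp over apex, and strategy: "fsdp" shards parameters and optimizer state across GPUs. Further down, main.py merges this trainer: block into the Trainer arguments; a minimal sketch of the equivalent direct construction, assuming PyTorch Lightning 1.7.x and a single-GPU run (the devices value is an assumption, not part of this config):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,                   # assumption: one GPU; the config does not pin this
    precision=16,                # fp16 autocast + GradScaler handled by Lightning
    amp_backend="native",        # torch.cuda.amp rather than NVIDIA apex
    strategy="fsdp",             # fully sharded data parallel
    benchmark=True,              # cudnn autotuning for fixed-size inputs
    limit_val_batches=0,         # validation disabled, matching the config
    num_sanity_val_steps=0,
    accumulate_grad_batches=1,
)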
@@ -0,0 +1,182 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image, ImageOps
+from torch.utils.data import Dataset
+from torchvision import transforms
+import torchvision.transforms.functional as TF
+
+from functools import partial
+import copy
+
+import glob
+
+import random
+
+PIL.Image.MAX_IMAGE_PIXELS = 933120000
+import torchvision
+
+import pytorch_lightning as pl
+
+import torch
+
+import re
+import json
+import io
+
+def resize_image(image: Image, max_size=(768,768)):
+    image = ImageOps.contain(image, max_size, Image.LANCZOS)
+    # resize to integer multiple of 64
+    w, h = image.size
+    w, h = map(lambda x: x - x % 64, (w, h))
+
+    ratio = w / h
+    src_ratio = image.width / image.height
+
+    src_w = w if ratio > src_ratio else image.width * h // image.height
+    src_h = h if ratio <= src_ratio else image.height * w // image.width
+
+    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
+    res = Image.new("RGB", (w, h))
+    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
+
+    return res
+
+class CaptionProcessor(object):
+    def __init__(self, transforms, max_size, resize, random_order, LR_size):
+        self.transforms = transforms
+        self.max_size = max_size
+        self.resize = resize
+        self.random_order = random_order
+        self.degradation_process = partial(TF.resize, size=LR_size, interpolation=TF.InterpolationMode.NEAREST)
+
+    def __call__(self, sample):
+        # preprocess caption
+        pass
+
+        # preprocess image
+        image = sample['image']
+        image = Image.open(io.BytesIO(image))
+        if self.resize:
+            image = resize_image(image, max_size=(self.max_size, self.max_size))
+        image = self.transforms(image)
+        lr_image = copy.deepcopy(image)
+        image = np.array(image).astype(np.uint8)
+        sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
+
+        # preprocess LR image
+        lr_image = self.degradation_process(lr_image)
+        lr_image = np.array(lr_image).astype(np.uint8)
+        sample['LR_image'] = (lr_image/127.5 - 1.0).astype(np.float32)
+
+        return sample
+
+class LocalDanbooruBaseVAE(Dataset):
+    def __init__(self,
+                 data_root='./danbooru-aesthetic',
+                 size=256,
+                 interpolation="bicubic",
+                 flip_p=0.5,
+                 crop=True,
+                 shuffle=False,
+                 mode='train',
+                 val_split=64,
+                 downscale_f=8
+                 ):
+        super().__init__()
+
+        self.shuffle=shuffle
+        self.crop = crop
+
+        print('Fetching data.')
+
+        ext = ['image']
+        self.image_files = []
+        [self.image_files.extend(glob.glob(f'{data_root}' + '/*.' + e)) for e in ext]
+        if mode == 'val':
+            self.image_files = self.image_files[:len(self.image_files)//val_split]
+
+        print(f'Constructing image map. Found {len(self.image_files)} images')
+
+        self.examples = {}
+        self.hashes = []
+        for i in self.image_files:
+            hash = i[len(f'{data_root}/'):].split('.')[0]
+            self.examples[hash] = {
+                'image': i
+            }
+            self.hashes.append(hash)
+
+        print(f'image map has {len(self.examples.keys())} examples')
+
+        self.size = size
+        self.interpolation = {"linear": PIL.Image.LINEAR,
+                              "bilinear": PIL.Image.BILINEAR,
+                              "bicubic": PIL.Image.BICUBIC,
+                              "lanczos": PIL.Image.LANCZOS,
+                              }[interpolation]
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+        image_transforms = []
+        image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(flip_p)],)
+        image_transforms = torchvision.transforms.Compose(image_transforms)
+
+        self.captionprocessor = CaptionProcessor(image_transforms, self.size, True, True, int(size / downscale_f))
+
+    def random_sample(self):
+        return self.__getitem__(random.randint(0, self.__len__() - 1))
+
+    def sequential_sample(self, i):
+        if i >= self.__len__() - 1:
+            return self.__getitem__(0)
+        return self.__getitem__(i + 1)
+
+    def skip_sample(self, i):
+        return None
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, i):
+        return self.get_image(i)
+
+    def get_image(self, i):
+        image = {}
+        try:
+            image_file = self.examples[self.hashes[i]]['image']
+            with open(image_file, 'rb') as f:
+                image['image'] = f.read()
+            image = self.captionprocessor(image)
+        except Exception as e:
+            print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
+            return self.skip_sample(i)
+
+        return image
+
+"""
+if __name__ == "__main__":
+    dataset = LocalBase('./danbooru-aesthetic', size=512, crop=False, mode='val')
+    print(dataset.__len__())
+    example = dataset.__getitem__(0)
+    print(dataset.hashes[0])
+    print(example['caption'])
+    image = example['image']
+    image = ((image + 1) * 127.5).astype(np.uint8)
+    image = Image.fromarray(image)
+    image.save('example.png')
+"""
+"""
+from tqdm import tqdm
+if __name__ == "__main__":
+    dataset = LocalDanbooruBase('./links', size=768)
+    import time
+    a = time.process_time()
+    for i in range(8):
+        example = dataset.get_image(i)
+        image = example['image']
+        image = ((image + 1) * 127.5).astype(np.uint8)
+        image = Image.fromarray(image)
+        image.save(f'example-{i}.png')
+        print(example['caption'])
+    print('time:', time.process_time()-a)
+"""
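A quick way to smoke-test the new dataset class, mirroring the commented-out __main__ blocks at the bottom of the file (the data_root path here is just the class default, not something this commit ships):

# Hypothetical smoke test; point data_root at a real image folder.
dataset = LocalDanbooruBaseVAE(data_root='./danbooru-aesthetic', size=512, mode='val', val_split=64)
print(len(dataset))

sample = dataset[0]                 # dict with 'image' and 'LR_image'
print(sample['image'].shape)        # (H, W, 3) float32, scaled to [-1, 1]
print(sample['LR_image'].shape)     # nearest-neighbour downscale of the same image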
@@ -340,6 +340,7 @@ class DDPM(pl.LightningModule):
         return loss, loss_dict

     def training_step(self, batch, batch_idx):
-        loss, loss_dict = self.shared_step(batch)
+        with torch.autocast('cuda'):
+            loss, loss_dict = self.shared_step(batch)

         self.log_dict(loss_dict, prog_bar=True,
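Wrapping the forward pass in torch.autocast('cuda') is the standard mixed-precision idiom: matmuls and convolutions run in fp16 while numerically sensitive ops stay in fp32, and the native AMP plugin wired up in main.py further down supplies the GradScaler for the backward pass. A generic sketch of the same pattern outside Lightning, for reference only:

import torch

def amp_train_step(model, batch, optimizer, scaler):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast('cuda'):          # forward in mixed precision
        loss = model(batch)
    scaler.scale(loss).backward()         # scale to avoid fp16 gradient underflow
    scaler.step(optimizer)                # unscale, then optimizer step
    scaler.update()
    return loss.detach()

# scaler = torch.cuda.amp.GradScaler(enabled=True)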
@@ -475,7 +476,7 @@ class LatentDiffusion(DDPM):

     @rank_zero_only
     @torch.no_grad()
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx):
         # only for very first batch
         if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
             assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
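The dropped dataloader_idx argument tracks the PyTorch Lightning 1.7 hook signatures (the pin is bumped to 1.7.7 at the bottom of this diff); the same adjustment is made to ImageLogger.on_train_batch_end in main.py. A minimal sketch of the hook as Lightning 1.7 expects it, illustrative only:

import pytorch_lightning as pl

class HookSignatureSketch(pl.LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # Lightning 1.7 passes only (batch, batch_idx); dataloader_idx was removed
        return None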
@@ -119,6 +119,7 @@ def checkpoint(func, inputs, params, flag):
 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, length, *args):
-        ctx.run_function = run_function
-        ctx.input_tensors = list(args[:length])
-        ctx.input_params = list(args[length:])
+        with torch.autocast('cuda'):
+            ctx.run_function = run_function
+            ctx.input_tensors = list(args[:length])
+            ctx.input_params = list(args[length:])
@@ -129,6 +130,7 @@ class CheckpointFunction(torch.autograd.Function):

     @staticmethod
     def backward(ctx, *output_grads):
-        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
-            # Fixes a bug where the first op in run_function modifies the
+        with torch.autocast('cuda'):
+            ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+            with torch.enable_grad():
+                # Fixes a bug where the first op in run_function modifies the
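The point of putting autocast in both forward and backward of the custom checkpoint Function is that the recomputation performed at backward time runs with the same fp16 autocast state as the original forward, so activations are rebuilt with matching dtypes. A generic sketch of the same idea using the stock utility (not this repo's code):

import torch
from torch.utils.checkpoint import checkpoint

def run_checkpointed(block, x):
    # Recompute block's activations during backward instead of caching them;
    # the surrounding autocast covers the mixed-precision forward pass.
    with torch.autocast('cuda'):
        return checkpoint(block, x)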
@@ -136,8 +136,7 @@ class SpatialRescaler(nn.Module):
         return self(x)

 class FrozenCLIPEmbedder(AbstractEncoder):
-    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, extended_mode=None):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -145,7 +144,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         self.max_length = max_length
         self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
         self.extended_mode = extended_mode
-        self.max_chunks = max_chunks
         self.freeze()

     def freeze(self):
@@ -168,7 +166,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         if self.extended_mode:
             max_standard_tokens = self.max_length - 2

-            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.extended_mode) - (self.extended_mode * 2), return_length=True, return_overflowing_tokens=False, padding=False,
                                             add_special_tokens=False)

             # get the max length aligned to chunk size.
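This is where the 225-token figure in the config comment comes from: each CLIP window holds max_length = 77 tokens, two of which are the BOS/EOS specials, leaving 75 content tokens per chunk; with extended_mode = 3 chunks that gives 3 * 75 = 225, which is exactly what the new max_length expression evaluates to. A quick check using the values assumed from the config above:

max_length, extended_mode = 77, 3                             # CLIP window size, chunks from the config
print((max_length * extended_mode) - (extended_mode * 2))     # -> 225 content tokens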
main.py
@@ -295,7 +295,7 @@ class ImageLogger(Callback):
         self.batch_freq = batch_frequency
         self.max_images = max_images
         self.logger_log_images = {
-            pl.loggers.TestTubeLogger: self._testtube,
+            pl.loggers.WandbLogger: self._testtube,
         }
         self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
         if not increase_log_steps:
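Image logging is rerouted from TestTube to Weights & Biases (wandb is added to the pinned requirements at the bottom of this diff); the existing _testtube upload method is simply reused under the new logger key. A minimal sketch of the logger the run would attach, with placeholder name/save_dir values:

from pytorch_lightning.loggers import WandbLogger

# Placeholder arguments -- main.py fills these in from the run name and log directory.
logger = WandbLogger(name="my-run", save_dir="logs")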
@@ -350,6 +350,7 @@ class ImageLogger(Callback):
                 pl_module.eval()

             with torch.no_grad():
-                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+                with torch.autocast('cuda'):
+                    images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

             for k in images:
@@ -380,7 +381,7 @@ class ImageLogger(Callback):
             return True
         return False

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
             self.log_img(pl_module, batch, batch_idx, split="train")
@@ -518,7 +519,7 @@ if __name__ == "__main__":
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
-        trainer_config["accelerator"] = "ddp"
+        trainer_config["accelerator"] = "gpu"
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if not "gpus" in trainer_config:
@@ -556,7 +557,7 @@ if __name__ == "__main__":
                }
            },
        }
-        default_logger_cfg = default_logger_cfgs["testtube"]
+        default_logger_cfg = default_logger_cfgs["wandb"]
        if "logger" in lightning_config:
            logger_cfg = lightning_config.logger
        else:
@@ -656,9 +657,11 @@ if __name__ == "__main__":

        trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
        trainer_kwargs["plugins"] = list()
-        from pytorch_lightning.plugins import DDPPlugin
-        trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin
+        #trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        trainer_kwargs["plugins"].append(NativeMixedPrecisionPlugin(16, 'cuda', torch.cuda.amp.GradScaler(enabled=True)))
        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+        #trainer = Trainer(gpus=1, precision=16, amp_backend="native", strategy="deepspeed_stage_2_offload", benchmark=True, limit_val_batches=0, num_sanity_val_steps=0, accumulate_grad_batches=1)
        trainer.logdir = logdir  ###

        # data
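Beyond the precision flag in the config, the fp16 machinery is also wired in explicitly here: NativeMixedPrecisionPlugin(16, 'cuda', GradScaler(...)) pairs fp16 autocast with dynamic loss scaling and is handed to the Trainer through the plugins list. A condensed sketch of the same wiring, assuming PyTorch Lightning 1.7.x:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

# Precision plugin carries the GradScaler that main.py's trainer_kwargs append above.
plugins = [NativeMixedPrecisionPlugin(16, 'cuda', torch.cuda.amp.GradScaler(enabled=True))]
trainer = Trainer(accelerator="gpu", devices=1, plugins=plugins, benchmark=True)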
@@ -4,7 +4,7 @@ opencv-python
 pudb==2019.2
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
-pytorch-lightning==1.6.0
+pytorch-lightning==1.7.7
 omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
@@ -18,3 +18,5 @@ git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
 webdataset
+wandb
+fairscale