diff --git a/configs/stable-diffusion/v1-4-finetune-test.yaml b/configs/stable-diffusion/v1-4-finetune-test.yaml
index eddbe34..e679b59 100644
--- a/configs/stable-diffusion/v1-4-finetune-test.yaml
+++ b/configs/stable-diffusion/v1-4-finetune-test.yaml
@@ -47,6 +47,7 @@ model:
       params:
         embed_dim: 4
         monitor: val/rec_loss
+        ckpt_path: "../latent-diffusion/logs/original/checkpoints/last.ckpt"
         ddconfig:
           double_z: true
           z_channels: 4
@@ -69,22 +70,25 @@ model:
       target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
       params:
         penultimate: true # use 2nd last layer - https://arxiv.org/pdf/2205.11487.pdf D.1
+        extended_mode: 3 # extend clip context to 225 tokens - as per NAI blogpost

 data:
   target: main.DataModuleFromConfig
   params:
-    batch_size: 4
-    num_workers: 4
+    batch_size: 2
+    num_workers: 2
     wrap: false
     train:
-      target: ldm.data.local.LocalDanbooruBase
+      target: ldm.data.localdanboorubase.LocalDanbooruBase
       params:
+        data_root: '../dataset'
         size: 512
         mode: "train"
         ucg: 0.1 # unconditional guidance training
     validation:
-      target: ldm.data.local.LocalDanbooruBase
+      target: ldm.data.localdanboorubase.LocalDanbooruBase
       params:
+        data_root: '../dataset'
         size: 512
         mode: "val"
         val_split: 64
@@ -109,9 +113,11 @@ lightning:
         plot_diffusion_rows: False
         N: 4
         ddim_steps: 50
-
-trainer:
-  benchmark: True
-  val_check_interval: 5000000
-  num_sanity_val_steps: 0
-  accumulate_grad_batches: 1
+  trainer:
+    precision: 16
+    amp_backend: "native"
+    strategy: "fsdp"
+    benchmark: True
+    limit_val_batches: 0
+    num_sanity_val_steps: 0
+    accumulate_grad_batches: 1
diff --git a/ldm/data/localdanboorubasevae.py b/ldm/data/localdanboorubasevae.py
new file mode 100644
index 0000000..a5cea0c
--- /dev/null
+++ b/ldm/data/localdanboorubasevae.py
@@ -0,0 +1,182 @@
+import os
+import numpy as np
+import PIL
+from PIL import Image, ImageOps
+from torch.utils.data import Dataset
+from torchvision import transforms
+import torchvision.transforms.functional as TF
+
+from functools import partial
+import copy
+
+import glob
+
+import random
+
+PIL.Image.MAX_IMAGE_PIXELS = 933120000
+import torchvision
+
+import pytorch_lightning as pl
+
+import torch
+
+import re
+import json
+import io
+
+def resize_image(image: Image, max_size=(768,768)):
+    image = ImageOps.contain(image, max_size, Image.LANCZOS)
+    # resize to integer multiple of 64
+    w, h = image.size
+    w, h = map(lambda x: x - x % 64, (w, h))
+
+    ratio = w / h
+    src_ratio = image.width / image.height
+
+    src_w = w if ratio > src_ratio else image.width * h // image.height
+    src_h = h if ratio <= src_ratio else image.height * w // image.width
+
+    resized = image.resize((src_w, src_h), resample=Image.LANCZOS)
+    res = Image.new("RGB", (w, h))
+    res.paste(resized, box=(w // 2 - src_w // 2, h // 2 - src_h // 2))
+
+    return res
+
+class CaptionProcessor(object):
+    def __init__(self, transforms, max_size, resize, random_order, LR_size):
+        self.transforms = transforms
+        self.max_size = max_size
+        self.resize = resize
+        self.random_order = random_order
+        self.degradation_process = partial(TF.resize, size=LR_size, interpolation=TF.InterpolationMode.NEAREST)
+
+    def __call__(self, sample):
+        # preprocess caption
+        pass
+
+        # preprocess image
+        image = sample['image']
+        image = Image.open(io.BytesIO(image))
+        if self.resize:
+            image = resize_image(image, max_size=(self.max_size, self.max_size))
+        image = self.transforms(image)
+        lr_image = copy.deepcopy(image)
+        image = np.array(image).astype(np.uint8)
+        sample['image'] = (image / 127.5 - 1.0).astype(np.float32)
+
+        # preprocess LR image
+        lr_image = self.degradation_process(lr_image)
+        lr_image = np.array(lr_image).astype(np.uint8)
+        sample['LR_image'] = (lr_image/127.5 - 1.0).astype(np.float32)
+
+        return sample
+
+class LocalDanbooruBaseVAE(Dataset):
+    def __init__(self,
+                 data_root='./danbooru-aesthetic',
+                 size=256,
+                 interpolation="bicubic",
+                 flip_p=0.5,
+                 crop=True,
+                 shuffle=False,
+                 mode='train',
+                 val_split=64,
+                 downscale_f=8
+                 ):
+        super().__init__()
+
+        self.shuffle=shuffle
+        self.crop = crop
+
+        print('Fetching data.')
+
+        ext = ['image']
+        self.image_files = []
+        [self.image_files.extend(glob.glob(f'{data_root}' + '/*.' + e)) for e in ext]
+        if mode == 'val':
+            self.image_files = self.image_files[:len(self.image_files)//val_split]
+
+        print(f'Constructing image map. Found {len(self.image_files)} images')
+
+        self.examples = {}
+        self.hashes = []
+        for i in self.image_files:
+            hash = i[len(f'{data_root}/'):].split('.')[0]
+            self.examples[hash] = {
+                'image': i
+            }
+            self.hashes.append(hash)
+
+        print(f'image map has {len(self.examples.keys())} examples')
+
+        self.size = size
+        self.interpolation = {"linear": PIL.Image.LINEAR,
+                              "bilinear": PIL.Image.BILINEAR,
+                              "bicubic": PIL.Image.BICUBIC,
+                              "lanczos": PIL.Image.LANCZOS,
+                              }[interpolation]
+        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+        image_transforms = []
+        image_transforms.extend([torchvision.transforms.RandomHorizontalFlip(flip_p)],)
+        image_transforms = torchvision.transforms.Compose(image_transforms)
+
+        self.captionprocessor = CaptionProcessor(image_transforms, self.size, True, True, int(size / downscale_f))
+
+    def random_sample(self):
+        return self.__getitem__(random.randint(0, self.__len__() - 1))
+
+    def sequential_sample(self, i):
+        if i >= self.__len__() - 1:
+            return self.__getitem__(0)
+        return self.__getitem__(i + 1)
+
+    def skip_sample(self, i):
+        return None
+
+    def __len__(self):
+        return len(self.image_files)
+
+    def __getitem__(self, i):
+        return self.get_image(i)
+
+    def get_image(self, i):
+        image = {}
+        try:
+            image_file = self.examples[self.hashes[i]]['image']
+            with open(image_file, 'rb') as f:
+                image['image'] = f.read()
+            image = self.captionprocessor(image)
+        except Exception as e:
+            print(f'Error with {self.examples[self.hashes[i]]["image"]} -- {e} -- skipping {i}')
+            return self.skip_sample(i)
+
+        return image
+
+"""
+if __name__ == "__main__":
+    dataset = LocalBase('./danbooru-aesthetic', size=512, crop=False, mode='val')
+    print(dataset.__len__())
+    example = dataset.__getitem__(0)
+    print(dataset.hashes[0])
+    print(example['caption'])
+    image = example['image']
+    image = ((image + 1) * 127.5).astype(np.uint8)
+    image = Image.fromarray(image)
+    image.save('example.png')
+"""
+"""
+from tqdm import tqdm
+if __name__ == "__main__":
+    dataset = LocalDanbooruBase('./links', size=768)
+    import time
+    a = time.process_time()
+    for i in range(8):
+        example = dataset.get_image(i)
+        image = example['image']
+        image = ((image + 1) * 127.5).astype(np.uint8)
+        image = Image.fromarray(image)
+        image.save(f'example-{i}.png')
+        print(example['caption'])
+        print('time:', time.process_time()-a)
+"""
diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py
index 5107c32..4cb5651 100644
--- a/ldm/models/diffusion/ddpm.py
+++ b/ldm/models/diffusion/ddpm.py
@@ -340,17 +340,18 @@ class DDPM(pl.LightningModule):
         return loss, loss_dict

     def training_step(self, batch, batch_idx):
-        loss, loss_dict = self.shared_step(batch)
+        with torch.autocast('cuda'):
+            loss, loss_dict = self.shared_step(batch)

-        self.log_dict(loss_dict, prog_bar=True,
-                      logger=True, on_step=True, on_epoch=True)
+            self.log_dict(loss_dict, prog_bar=True,
+                          logger=True, on_step=True, on_epoch=True)

-        self.log("global_step", self.global_step,
-                 prog_bar=True, logger=True, on_step=True, on_epoch=False)
+            self.log("global_step", self.global_step,
+                     prog_bar=True, logger=True, on_step=True, on_epoch=False)

-        if self.use_scheduler:
-            lr = self.optimizers().param_groups[0]['lr']
-            self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
+            if self.use_scheduler:
+                lr = self.optimizers().param_groups[0]['lr']
+                self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)

         return loss

@@ -475,7 +476,7 @@ class LatentDiffusion(DDPM):

     @rank_zero_only
     @torch.no_grad()
-    def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
+    def on_train_batch_start(self, batch, batch_idx):
         # only for very first batch
         if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 and batch_idx == 0 and not self.restarted_from_ckpt:
             assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously'
diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py
index a952e6c..9e514ca 100644
--- a/ldm/modules/diffusionmodules/util.py
+++ b/ldm/modules/diffusionmodules/util.py
@@ -119,33 +119,35 @@ def checkpoint(func, inputs, params, flag):

 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
     def forward(ctx, run_function, length, *args):
-        ctx.run_function = run_function
-        ctx.input_tensors = list(args[:length])
-        ctx.input_params = list(args[length:])
+        with torch.autocast('cuda'):
+            ctx.run_function = run_function
+            ctx.input_tensors = list(args[:length])
+            ctx.input_params = list(args[length:])

-        with torch.no_grad():
-            output_tensors = ctx.run_function(*ctx.input_tensors)
-        return output_tensors
+            with torch.no_grad():
+                output_tensors = ctx.run_function(*ctx.input_tensors)
+            return output_tensors

     @staticmethod
     def backward(ctx, *output_grads):
-        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
-        with torch.enable_grad():
-            # Fixes a bug where the first op in run_function modifies the
-            # Tensor storage in place, which is not allowed for detach()'d
-            # Tensors.
-            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
-            output_tensors = ctx.run_function(*shallow_copies)
-        input_grads = torch.autograd.grad(
-            output_tensors,
-            ctx.input_tensors + ctx.input_params,
-            output_grads,
-            allow_unused=True,
-        )
-        del ctx.input_tensors
-        del ctx.input_params
-        del output_tensors
-        return (None, None) + input_grads
+        with torch.autocast('cuda'):
+            ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+            with torch.enable_grad():
+                # Fixes a bug where the first op in run_function modifies the
+                # Tensor storage in place, which is not allowed for detach()'d
+                # Tensors.
+                shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+                output_tensors = ctx.run_function(*shallow_copies)
+            input_grads = torch.autograd.grad(
+                output_tensors,
+                ctx.input_tensors + ctx.input_params,
+                output_grads,
+                allow_unused=True,
+            )
+            del ctx.input_tensors
+            del ctx.input_params
+            del output_tensors
+            return (None, None) + input_grads


 def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 17152ec..4d9f08c 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -136,8 +136,7 @@ class SpatialRescaler(nn.Module):
         return self(x)

 class FrozenCLIPEmbedder(AbstractEncoder):
-    """Uses the CLIP transformer encoder for text (from Hugging Face)"""
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, max_chunks=3, extended_mode=True):
+    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, penultimate=True, extended_mode=None):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(version)
         self.transformer = CLIPTextModel.from_pretrained(version)
@@ -145,7 +144,6 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         self.max_length = max_length
         self.penultimate = penultimate # return embeddings from 2nd to last layer, see https://arxiv.org/pdf/2205.11487.pdf
         self.extended_mode = extended_mode
-        self.max_chunks = max_chunks
         self.freeze()

     def freeze(self):
@@ -168,7 +166,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):

         if self.extended_mode:
             max_standard_tokens = self.max_length - 2
-            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.max_chunks) - (self.max_chunks * 2), return_length=True, return_overflowing_tokens=False, padding=False,
+            batch_encoding = self.tokenizer(text, truncation=True, max_length=(self.max_length * self.extended_mode) - (self.extended_mode * 2), return_length=True, return_overflowing_tokens=False, padding=False,
                                             add_special_tokens=False)

             # get the max length aligned to chunk size.
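For context on the extended_mode change above: the idea (per the NAI blogpost referenced in the config comment) is to let the text encoder cover roughly 225 usable tokens by tokenizing the prompt without special tokens, splitting it into 75-token chunks, wrapping each chunk with BOS/EOS, encoding each chunk with CLIP, and concatenating the per-chunk hidden states. The snippet below is only an illustrative sketch of that general technique against the Hugging Face transformers API; the function name encode_extended and its padding details are assumptions, not this repo's implementation (which, among other differences, can also return the penultimate layer when penultimate=True).

# Illustrative sketch only, not code from this patch.
import torch
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").eval()

def encode_extended(prompt, max_length=77, chunks=3):
    # 75 content tokens per chunk, leaving room for BOS/EOS in each chunk
    max_std = max_length - 2
    ids = tokenizer(prompt, truncation=True,
                    max_length=max_std * chunks,
                    add_special_tokens=False)["input_ids"]
    embeds = []
    for start in range(0, max(len(ids), 1), max_std):
        chunk = ids[start:start + max_std]
        # wrap with BOS, pad the tail with EOS so every chunk is exactly max_length long
        chunk = [tokenizer.bos_token_id] + chunk \
                + [tokenizer.eos_token_id] * (max_length - 1 - len(chunk))
        with torch.no_grad():
            out = text_model(input_ids=torch.tensor([chunk]))
        embeds.append(out.last_hidden_state)
    # shape (1, n_chunks * 77, 768); with chunks=3 that is the "225 token" context
    return torch.cat(embeds, dim=1)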
diff --git a/main.py b/main.py
index f32fce0..fe1701d 100644
--- a/main.py
+++ b/main.py
@@ -295,7 +295,7 @@ class ImageLogger(Callback):
         self.batch_freq = batch_frequency
         self.max_images = max_images
         self.logger_log_images = {
-            pl.loggers.TestTubeLogger: self._testtube,
+            pl.loggers.WandbLogger: self._testtube,
         }
         self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
         if not increase_log_steps:
@@ -350,7 +350,8 @@ class ImageLogger(Callback):
                 pl_module.eval()

             with torch.no_grad():
-                images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
+                with torch.autocast('cuda'):
+                    images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)

             for k in images:
                 N = min(images[k].shape[0], self.max_images)
@@ -380,7 +381,7 @@ class ImageLogger(Callback):
                 return True
         return False

-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
             self.log_img(pl_module, batch, batch_idx, split="train")

@@ -518,7 +519,7 @@ if __name__ == "__main__":
         # merge trainer cli with config
         trainer_config = lightning_config.get("trainer", OmegaConf.create())
         # default to ddp
-        trainer_config["accelerator"] = "ddp"
+        trainer_config["accelerator"] = "gpu"
         for k in nondefault_trainer_args(opt):
             trainer_config[k] = getattr(opt, k)
         if not "gpus" in trainer_config:
@@ -556,7 +557,7 @@ if __name__ == "__main__":
                 }
             },
         }
-        default_logger_cfg = default_logger_cfgs["testtube"]
+        default_logger_cfg = default_logger_cfgs["wandb"]
         if "logger" in lightning_config:
             logger_cfg = lightning_config.logger
         else:
@@ -656,9 +657,11 @@ if __name__ == "__main__":
         trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
         trainer_kwargs["plugins"] = list()

-        from pytorch_lightning.plugins import DDPPlugin
-        trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        from pytorch_lightning.plugins import DDPPlugin, NativeMixedPrecisionPlugin
+        #trainer_kwargs["plugins"].append(DDPPlugin(find_unused_parameters=False))
+        trainer_kwargs["plugins"].append(NativeMixedPrecisionPlugin(16, 'cuda', torch.cuda.amp.GradScaler(enabled=True)))
         trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
+        #trainer = Trainer(gpus=1, precision=16, amp_backend="native", strategy="deepspeed_stage_2_offload", benchmark=True, limit_val_batches=0, num_sanity_val_steps=0, accumulate_grad_batches=1)
         trainer.logdir = logdir  ###

         # data
diff --git a/requirements.txt b/requirements.txt
index b6dee37..82c04b2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,7 +4,7 @@ opencv-python
 pudb==2019.2
 imageio==2.9.0
 imageio-ffmpeg==0.4.2
-pytorch-lightning==1.6.0
+pytorch-lightning==1.7.7
 omegaconf==2.1.1
 test-tube>=0.7.5
 streamlit>=0.73.1
@@ -17,4 +17,6 @@ gradio
 git+https://github.com/illeatmyhat/taming-transformers.git@master#egg=taming-transformers
 git+https://github.com/openai/CLIP.git@main#egg=clip
 git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
-webdataset
\ No newline at end of file
+webdataset
+wandb
+fairscale
\ No newline at end of file
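For context on the training setup: the new lightning.trainer block in the config and the NativeMixedPrecisionPlugin wiring in main.py both target pytorch-lightning 1.7.x with native AMP and fairscale FSDP (hence the new fairscale requirement). The snippet below is an illustrative sketch of roughly the Trainer those settings describe, not code from this patch; devices=1 is an assumption (the patch leaves GPU selection to the CLI), and the precision flags here stand in for, rather than duplicate, the precision plugin that main.py injects.

# Illustrative sketch only, assuming pytorch-lightning 1.7.x with fairscale installed.
from pytorch_lightning import Trainer

trainer = Trainer(
    accelerator="gpu",           # main.py now defaults to "gpu" instead of "ddp"
    devices=1,                   # assumption: single GPU
    precision=16,                # native AMP, in the spirit of the
    amp_backend="native",        # NativeMixedPrecisionPlugin added in main.py
    strategy="fsdp",             # fairscale fully sharded data parallel
    benchmark=True,
    limit_val_batches=0,         # validation disabled, as in the config
    num_sanity_val_steps=0,
    accumulate_grad_batches=1,
)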