From 8904724135fd922770cd77a92e222f1138944847 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Tue, 20 Dec 2022 03:30:42 -0500 Subject: [PATCH] fix some quality issues --- data/every_dream.py | 56 +++++++++++++------------------------- data/image_train_item.py | 11 ++++++-- train.py | 58 +++++++++++++++++++++++++++------------- 3 files changed, 68 insertions(+), 57 deletions(-) diff --git a/data/every_dream.py b/data/every_dream.py index 1f1066e..1486269 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -87,51 +87,30 @@ class EveryDreamBatch(Dataset): return self._length def __getitem__(self, i): - #print(" * Getting item", i) - # batch = dict() - # batch["images"] = list() - # batch["captions"] = list() - # first = True - # for j in range(i, i + self.batch_size - 1): - # if j < self.num_images: - # example = self.__get_image_for_trainer(self.image_train_items[j], self.debug_level) - # if first: - # print(f"first example {j}", example) - # batch["images"] = [torch.from_numpy(example["image"])] - # batch["captions"] = [example["caption"]] - # first = False - # else: - # print(f"subsiquent example {j}", example) - # batch["images"].extend(torch.from_numpy(example["image"])) - # batch["captions"].extend(example["caption"]) example = {} train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level) - #example["image"] = torch.from_numpy(train_item["image"]) example["image"] = self.image_transforms(train_item["image"]) - # if train_item["caption"] == " ": - # example["tokens"] = [0 for i in range(self.max_token_length-2)] + + #if random.random() > 9999: + example["tokens"] = self.tokenizer(train_item["caption"], + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + ).input_ids + #print(example["tokens"]) + example["tokens"] = torch.tensor(example["tokens"]) # else: - if random.random() > self.conditional_dropout: - example["tokens"] = self.tokenizer(train_item["caption"], - #padding="max_length", - truncation=True, - padding=False, - add_special_tokens=False, - max_length=self.max_token_length-2, - ).input_ids - example["tokens"] = torch.tensor(example["tokens"]) - else: - example["tokens"] = torch.zeros(75, dtype=torch.int) + # example["tokens"] = torch.zeros(75, dtype=torch.int) #print(f"bos: {self.tokenizer.bos_token_id}{self.tokenizer.eos_token_id}") #print(f"example['tokens']: {example['tokens']}") - pad_amt = self.max_token_length-2 - len(example["tokens"]) - example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0) - example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id)) - eos_int = int(self.tokenizer.eos_token_id) + #pad_amt = self.max_token_length-2 - len(example["tokens"]) + #example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0) + #example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id)) + #eos_int = int(self.tokenizer.eos_token_id) #eos_int = int(0) - example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int) + #example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int) #print(f"__getitem__ train_item['caption']: {train_item['caption']}") #print(f"__getitem__ train_item['pathname']: {train_item['pathname']}") #print(f"__getitem__ example['tokens'] pad: {example['tokens']}") @@ -149,6 +128,9 @@ class EveryDreamBatch(Dataset): image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter) example["image"] = image_train_tmp.image - example["caption"] = image_train_tmp.caption + if random.random() > self.conditional_dropout: + example["caption"] = image_train_tmp.caption + else: + example["caption"] = " " return example diff --git a/data/image_train_item.py b/data/image_train_item.py index e9214b6..983ad35 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -19,6 +19,7 @@ from torchvision import transforms, utils import random import math import os +import logging _RANDOM_TRIM = 0.04 @@ -48,7 +49,9 @@ class ImageTrainItem(): save: save the cropped image to disk, for manual inspection of resize/crop crop_jitter: randomly shift cropp by N pixels when using multiple aspect ratios to improve training quality """ - if not hasattr(self, 'image') or len(self.image) == 0: + #print(self.pathname, self.image) + try: + #if not hasattr(self, 'image'): self.image = PIL.Image.open(self.pathname).convert('RGB') width, height = self.image.size @@ -96,6 +99,10 @@ class ImageTrainItem(): self.image = self.image.resize(self.target_wh, resample=PIL.Image.BICUBIC) self.image = self.flip(self.image) + except Exception as e: + logging.error(f"Error loading image: {self.pathname}") + print(e) + exit() if type(self.image) is not np.ndarray: if save: @@ -106,7 +113,7 @@ class ImageTrainItem(): self.image = np.array(self.image).astype(np.uint8) - self.image = (self.image / 127.5 - 1.0).astype(np.float32) + #self.image = (self.image / 127.5 - 1.0).astype(np.float32) #print(self.image.shape) diff --git a/train.py b/train.py index d66e49b..3ec3926 100644 --- a/train.py +++ b/train.py @@ -20,6 +20,7 @@ import signal import argparse import logging import time +import gc import torch.nn.functional as torch_functional from torch.cuda.amp import autocast @@ -49,6 +50,8 @@ from data.every_dream import EveryDreamBatch from utils.convert_diffusers_to_stable_diffusion import convert as converter from utils.gpu import GPU +import debug + _GRAD_ACCUM_STEPS = 1 # future use... _SIGTERM_EXIT_CODE = 130 _VERY_LARGE_NUMBER = 1e9 @@ -159,6 +162,11 @@ def main(args): seed = 555 set_seed(seed) gpu = GPU() + torch.backends.cudnn.benchmark = False + args.clip_skip = max(min(2, args.clip_skip), 0) + + if args.text_encoder_epochs is None or args.text_encoder_epochs < 1: + args.text_encoder_epochs = _VERY_LARGE_NUMBER if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None: logging.info(" no checkpointing specified, defaulting to 20 minutes") @@ -279,7 +287,7 @@ def main(args): i = 0 for prompt in prompts: if prompt is None or len(prompt) < 2: - logging.warning("empty prompt in sample prompts, check your prompts file") + #logging.warning("empty prompt in sample prompts, check your prompts file") continue images = [] for cfg in [7.0, 4.0, 1.01]: @@ -301,14 +309,16 @@ def main(args): clean_prompt = clean_filename(prompt) - result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png") + result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False) + with open(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.txt", "w") as f: + f.write(prompt) tfimage = transforms.ToTensor()(result) if random_captions: log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs) - i += 1 else: - log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs) + log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs) + i += 1 del result del tfimage @@ -406,7 +416,7 @@ def main(args): for line in f: sample_prompts.append(line.strip()) - log_folder = os.path.join("logs", f"{args.project_name}{log_time}") + log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}") logging.info(f"Logging to {log_folder}") if False: #args.wandb is not None and args.wandb: # not yet supported @@ -428,7 +438,7 @@ def main(args): log_args(log_writer, args) - args.clip_skip = max(min(2, args.clip_skip), 0) + """ Train the model @@ -547,17 +557,20 @@ def main(args): save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}") __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir) + if epoch == args.text_encoder_epochs: + logging.info(f" Freezing text_encoder at epoch: {epoch}") + text_encoder.requires_grad_(False) + text_encoder.eval() + epoch_start_time = time.time() steps_pbar.reset() images_per_sec_epoch = [] for step, batch in enumerate(train_dataloader): + if args.debug: + debug.debug_batch(batch) step_start_time = time.time() - if global_step > 0 and global_step > args.text_encoder_steps == 0: - text_encoder.requires_grad_(False) - text_encoder.eval() - with torch.no_grad(): with autocast(): pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device) @@ -576,10 +589,17 @@ def main(args): cuda_caption = batch["tokens"].to(text_encoder.device) with autocast(enabled=args.amp): - encoder_hidden_states = text_encoder(cuda_caption) + #encoder_hidden_states = text_encoder(cuda_caption) + encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True) - # if clip_skip > 0: #TODO - # encoder_hidden_states = encoder_hidden_states['last_hidden_state'][-clip_skip] + #print("encoder_hidden_states.keys()", encoder_hidden_states.keys()) + #print("encoder_hidden_states.hidden_states.shape", encoder_hidden_states.hidden_states[0].shape) + #print("encoder_hidden_states.last_hidden_state.shape", encoder_hidden_states.last_hidden_state[0].shape) + + if args.clip_skip > 0: # TODO + encoder_hidden_states = encoder_hidden_states.hidden_states[-args.clip_skip] + else: + encoder_hidden_states = encoder_hidden_states.last_hidden_state noisy_latents = scheduler.add_noise(latents, noise, timesteps) @@ -592,7 +612,7 @@ def main(args): #del noise, latents with autocast(): # xformers requires autocast - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states.last_hidden_state).sample + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample #with autocast(enabled=args.amp): loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean") @@ -607,8 +627,8 @@ def main(args): loss.backward() #scaler.unscale_(optimizer) #if args.clip_grad_norm is not None: - torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=2) - torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=2) + torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=1) + torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=1) #scaler.step(optimizer) #scaler.update() optimizer.step() @@ -654,6 +674,7 @@ def main(args): del pipe torch.cuda.empty_cache() + gc.collect() min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60 @@ -714,7 +735,7 @@ if __name__ == "__main__": argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)") argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?") argparser.add_argument("--grad_accum", type=int, default=1, help="NONFUNCTIONING. Gradient accumulation factor (def: 1), (ex, 2)") - argparser.add_argument("--clip_skip", type=int, default=0, help="NONFUNCTIONING. Train using penultimate layers (def: 0)", choices=[0, 1, 2]) + argparser.add_argument("--clip_skip", type=int, default=2, help="Train using penultimate layer (def: 2)", choices=[0, 1, 2]) argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are") argparser.add_argument("--mixed_precision", default="no", help="NONFUNCTIONING. precision, (default: NO for fp32)", choices=["NO", "fp16", "bf16"]) argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY") @@ -724,7 +745,8 @@ if __name__ == "__main__": argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)") argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)") argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)") - argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)") + argparser.add_argument("--text_encoder_epochs", type=int, default=0, help="disable text encoder training after N steps (def: disabled)") + argparser.add_argument("--debug", action="store_true", default=False) args = argparser.parse_args() main(args)