fix some quality issues

Victor Hall 2022-12-20 03:30:42 -05:00
parent 07dbf64ecf
commit 8904724135
3 changed files with 68 additions and 57 deletions

View File

@@ -87,51 +87,30 @@ class EveryDreamBatch(Dataset):
return self._length
def __getitem__(self, i):
#print(" * Getting item", i)
# batch = dict()
# batch["images"] = list()
# batch["captions"] = list()
# first = True
# for j in range(i, i + self.batch_size - 1):
# if j < self.num_images:
# example = self.__get_image_for_trainer(self.image_train_items[j], self.debug_level)
# if first:
# print(f"first example {j}", example)
# batch["images"] = [torch.from_numpy(example["image"])]
# batch["captions"] = [example["caption"]]
# first = False
# else:
# print(f"subsiquent example {j}", example)
# batch["images"].extend(torch.from_numpy(example["image"]))
# batch["captions"].extend(example["caption"])
example = {}
train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)
#example["image"] = torch.from_numpy(train_item["image"])
example["image"] = self.image_transforms(train_item["image"])
# if train_item["caption"] == " ":
# example["tokens"] = [0 for i in range(self.max_token_length-2)]
#if random.random() > 9999:
example["tokens"] = self.tokenizer(train_item["caption"],
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
).input_ids
#print(example["tokens"])
example["tokens"] = torch.tensor(example["tokens"])
# else:
if random.random() > self.conditional_dropout:
example["tokens"] = self.tokenizer(train_item["caption"],
#padding="max_length",
truncation=True,
padding=False,
add_special_tokens=False,
max_length=self.max_token_length-2,
).input_ids
example["tokens"] = torch.tensor(example["tokens"])
else:
example["tokens"] = torch.zeros(75, dtype=torch.int)
# example["tokens"] = torch.zeros(75, dtype=torch.int)
#print(f"bos: {self.tokenizer.bos_token_id}{self.tokenizer.eos_token_id}")
#print(f"example['tokens']: {example['tokens']}")
pad_amt = self.max_token_length-2 - len(example["tokens"])
example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0)
example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id))
eos_int = int(self.tokenizer.eos_token_id)
#pad_amt = self.max_token_length-2 - len(example["tokens"])
#example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0)
#example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id))
#eos_int = int(self.tokenizer.eos_token_id)
#eos_int = int(0)
example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int)
#example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int)
#print(f"__getitem__ train_item['caption']: {train_item['caption']}")
#print(f"__getitem__ train_item['pathname']: {train_item['pathname']}")
#print(f"__getitem__ example['tokens'] pad: {example['tokens']}")
@@ -149,6 +128,9 @@ class EveryDreamBatch(Dataset):
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
example["image"] = image_train_tmp.image
example["caption"] = image_train_tmp.caption
if random.random() > self.conditional_dropout:
example["caption"] = image_train_tmp.caption
else:
example["caption"] = " "
return example
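The added branch applies conditional dropout at the caption level: with probability cond_dropout the caption is replaced by a blank string, so the model also trains on unconditional examples, the prior that classifier-free guidance samples against. A minimal sketch, assuming the default of 0.04 from the argparse flag below:

import random

def maybe_drop_caption(caption: str, cond_dropout: float = 0.04) -> str:
    # Keep the caption with probability (1 - cond_dropout), else blank it.
    return caption if random.random() > cond_dropout else " "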

View File

@@ -19,6 +19,7 @@ from torchvision import transforms, utils
import random
import math
import os
import logging
_RANDOM_TRIM = 0.04
@@ -48,7 +49,9 @@ class ImageTrainItem():
save: save the cropped image to disk, for manual inspection of resize/crop
crop_jitter: randomly shift the crop by N pixels when using multiple aspect ratios to improve training quality
"""
if not hasattr(self, 'image') or len(self.image) == 0:
#print(self.pathname, self.image)
try:
#if not hasattr(self, 'image'):
self.image = PIL.Image.open(self.pathname).convert('RGB')
width, height = self.image.size
@@ -96,6 +99,10 @@ class ImageTrainItem():
self.image = self.image.resize(self.target_wh, resample=PIL.Image.BICUBIC)
self.image = self.flip(self.image)
except Exception as e:
logging.error(f"Error loading image: {self.pathname}")
print(e)
exit()
if type(self.image) is not np.ndarray:
if save:
@@ -106,7 +113,7 @@
self.image = np.array(self.image).astype(np.uint8)
self.image = (self.image / 127.5 - 1.0).astype(np.float32)
#self.image = (self.image / 127.5 - 1.0).astype(np.float32)
#print(self.image.shape)
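The hunk above swaps the hasattr guard for a try/except that logs the offending path before bailing out, so a single corrupt file no longer dies as an anonymous mid-epoch stack trace. A sketch of the same fail-fast pattern (re-raising instead of calling exit() is an assumption, not what the commit does):

import logging
import PIL.Image

def load_rgb(pathname: str) -> PIL.Image.Image:
    try:
        return PIL.Image.open(pathname).convert('RGB')
    except Exception:
        # Log the path so the bad file is easy to locate, then re-raise.
        logging.error(f"Error loading image: {pathname}")
        raise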

View File

@@ -20,6 +20,7 @@ import signal
import argparse
import logging
import time
import gc
import torch.nn.functional as torch_functional
from torch.cuda.amp import autocast
@@ -49,6 +50,8 @@ from data.every_dream import EveryDreamBatch
from utils.convert_diffusers_to_stable_diffusion import convert as converter
from utils.gpu import GPU
import debug
_GRAD_ACCUM_STEPS = 1 # future use...
_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9
@@ -159,6 +162,11 @@ def main(args):
seed = 555
set_seed(seed)
gpu = GPU()
torch.backends.cudnn.benchmark = False
args.clip_skip = max(min(2, args.clip_skip), 0)
if args.text_encoder_epochs is None or args.text_encoder_epochs < 1:
args.text_encoder_epochs = _VERY_LARGE_NUMBER
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(" no checkpointing specified, defaulting to 20 minutes")
@@ -279,7 +287,7 @@ def main(args):
i = 0
for prompt in prompts:
if prompt is None or len(prompt) < 2:
logging.warning("empty prompt in sample prompts, check your prompts file")
#logging.warning("empty prompt in sample prompts, check your prompts file")
continue
images = []
for cfg in [7.0, 4.0, 1.01]:
@@ -301,14 +309,16 @@ def main(args):
clean_prompt = clean_filename(prompt)
result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")
result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.txt", "w") as f:
f.write(prompt)
tfimage = transforms.ToTensor()(result)
if random_captions:
log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
i += 1
else:
log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)
log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
i += 1
del result
del tfimage
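The sample-logging change writes the prompt to a sidecar .txt next to each saved image and prefixes every TensorBoard tag with the sample index, so two prompts that share a long common prefix no longer overwrite each other's image slot. A sketch of the tagging scheme, assuming log_writer is a torch.utils.tensorboard SummaryWriter and clean_prompt is the sanitized prompt string:

from torchvision import transforms

def log_sample(log_writer, result, clean_prompt, i, gs):
    # Index-first tag keeps each sample slot unique across prompts.
    tfimage = transforms.ToTensor()(result)
    log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)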
@@ -406,7 +416,7 @@ def main(args):
for line in f:
sample_prompts.append(line.strip())
log_folder = os.path.join("logs", f"{args.project_name}{log_time}")
log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}")
logging.info(f"Logging to {log_folder}")
if False: #args.wandb is not None and args.wandb: # not yet supported
@@ -428,7 +438,7 @@ def main(args):
log_args(log_writer, args)
args.clip_skip = max(min(2, args.clip_skip), 0)
"""
Train the model
@@ -547,17 +557,20 @@ def main(args):
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
if epoch == args.text_encoder_epochs:
logging.info(f" Freezing text_encoder at epoch: {epoch}")
text_encoder.requires_grad_(False)
text_encoder.eval()
epoch_start_time = time.time()
steps_pbar.reset()
images_per_sec_epoch = []
for step, batch in enumerate(train_dataloader):
if args.debug:
debug.debug_batch(batch)
step_start_time = time.time()
if global_step > 0 and global_step > args.text_encoder_steps == 0:
text_encoder.requires_grad_(False)
text_encoder.eval()
with torch.no_grad():
with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
@@ -576,10 +589,17 @@ def main(args):
cuda_caption = batch["tokens"].to(text_encoder.device)
with autocast(enabled=args.amp):
encoder_hidden_states = text_encoder(cuda_caption)
#encoder_hidden_states = text_encoder(cuda_caption)
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
# if clip_skip > 0: #TODO
# encoder_hidden_states = encoder_hidden_states['last_hidden_state'][-clip_skip]
#print("encoder_hidden_states.keys()", encoder_hidden_states.keys())
#print("encoder_hidden_states.hidden_states.shape", encoder_hidden_states.hidden_states[0].shape)
#print("encoder_hidden_states.last_hidden_state.shape", encoder_hidden_states.last_hidden_state[0].shape)
if args.clip_skip > 0: # TODO
encoder_hidden_states = encoder_hidden_states.hidden_states[-args.clip_skip]
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
noisy_latents = scheduler.add_noise(latents, noise, timesteps)
@@ -592,7 +612,7 @@ def main(args):
#del noise, latents
with autocast(): # xformers requires autocast
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states.last_hidden_state).sample
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
#with autocast(enabled=args.amp):
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
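The clip_skip change asks the text encoder for all hidden states and, when clip_skip > 0, feeds the UNet an earlier layer; clip_skip=2 selects the penultimate transformer layer. A compact restatement as a helper, assuming a Hugging Face CLIPTextModel (encode_caption is a hypothetical name):

def encode_caption(text_encoder, token_ids, clip_skip=2):
    # hidden_states[-1] is the final layer, so -clip_skip with clip_skip=2
    # is the penultimate layer; clip_skip=0 falls back to last_hidden_state.
    out = text_encoder(token_ids, output_hidden_states=True)
    if clip_skip > 0:
        return out.hidden_states[-clip_skip]
    return out.last_hidden_state

One caveat worth knowing: in Hugging Face CLIP the entries of hidden_states are taken before the final layer norm, so hidden_states[-1] is not bit-identical to last_hidden_state.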
@@ -607,8 +627,8 @@ def main(args):
loss.backward()
#scaler.unscale_(optimizer)
#if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=2)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=2)
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=1)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=1)
#scaler.step(optimizer)
#scaler.update()
optimizer.step()
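The clip threshold also drops from 2 to 1 here and is applied per model, which gives the UNet and the text encoder independent norm budgets. An alternative sketch (not what the commit does) that clips both parameter sets under one shared budget:

import itertools
import torch

def clip_both(unet, text_encoder, max_norm=1.0):
    # A single call over the chained parameters enforces one joint norm,
    # versus the two separate per-model budgets used above.
    torch.nn.utils.clip_grad_norm_(
        itertools.chain(unet.parameters(), text_encoder.parameters()),
        max_norm=max_norm,
    )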
@@ -654,6 +674,7 @@ def main(args):
del pipe
torch.cuda.empty_cache()
gc.collect()
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
@@ -714,7 +735,7 @@ if __name__ == "__main__":
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
argparser.add_argument("--grad_accum", type=int, default=1, help="NONFUNCTIONING. Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--clip_skip", type=int, default=0, help="NONFUNCTIONING. Train using penultimate layers (def: 0)", choices=[0, 1, 2])
argparser.add_argument("--clip_skip", type=int, default=2, help="Train using penultimate layer (def: 2)", choices=[0, 1, 2])
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--mixed_precision", default="no", help="NONFUNCTIONING. precision, (default: NO for fp32)", choices=["NO", "fp16", "bf16"])
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
@@ -724,7 +745,8 @@ if __name__ == "__main__":
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
argparser.add_argument("--text_encoder_epochs", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
argparser.add_argument("--debug", action="store_true", default=False)
args = argparser.parse_args()
main(args)