fix some quality issues

parent 07dbf64ecf
commit 8904724135
@@ -87,51 +87,30 @@ class EveryDreamBatch(Dataset):
         return self._length

     def __getitem__(self, i):
-        #print(" * Getting item", i)
-        # batch = dict()
-        # batch["images"] = list()
-        # batch["captions"] = list()
-        # first = True
-        # for j in range(i, i + self.batch_size - 1):
-        #     if j < self.num_images:
-        #         example = self.__get_image_for_trainer(self.image_train_items[j], self.debug_level)
-        #         if first:
-        #             print(f"first example {j}", example)
-        #             batch["images"] = [torch.from_numpy(example["image"])]
-        #             batch["captions"] = [example["caption"]]
-        #             first = False
-        #         else:
-        #             print(f"subsiquent example {j}", example)
-        #             batch["images"].extend(torch.from_numpy(example["image"]))
-        #             batch["captions"].extend(example["caption"])
         example = {}

         train_item = self.__get_image_for_trainer(self.image_train_items[i], self.debug_level)
-        #example["image"] = torch.from_numpy(train_item["image"])
         example["image"] = self.image_transforms(train_item["image"])

-        # if train_item["caption"] == " ":
-        #     example["tokens"] = [0 for i in range(self.max_token_length-2)]
-        #if random.random() > 9999:
-        example["tokens"] = self.tokenizer(train_item["caption"],
-                                           truncation=True,
-                                           padding="max_length",
-                                           max_length=self.tokenizer.model_max_length,
-                                          ).input_ids
-        #print(example["tokens"])
-        example["tokens"] = torch.tensor(example["tokens"])
-        # else:
+        if random.random() > self.conditional_dropout:
+            example["tokens"] = self.tokenizer(train_item["caption"],
+                                               #padding="max_length",
+                                               truncation=True,
+                                               padding=False,
+                                               add_special_tokens=False,
+                                               max_length=self.max_token_length-2,
+                                              ).input_ids
+            example["tokens"] = torch.tensor(example["tokens"])
+        else:
+            example["tokens"] = torch.zeros(75, dtype=torch.int)
-            # example["tokens"] = torch.zeros(75, dtype=torch.int)
         #print(f"bos: {self.tokenizer.bos_token_id}{self.tokenizer.eos_token_id}")

         #print(f"example['tokens']: {example['tokens']}")
+        pad_amt = self.max_token_length-2 - len(example["tokens"])
+        example['tokens'] = F.pad(example['tokens'], pad=(0, pad_amt), mode='constant', value=0)
+        example['tokens'] = F.pad(example['tokens'], pad=(1, 0), mode='constant', value=int(self.tokenizer.bos_token_id))
+        eos_int = int(self.tokenizer.eos_token_id)
-        #pad_amt = self.max_token_length-2 - len(example["tokens"])
-        #example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0)
-        #example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id))
-        #eos_int = int(self.tokenizer.eos_token_id)
-        #eos_int = int(0)
+        example['tokens'] = F.pad(example['tokens'], pad=(0, 1), mode='constant', value=eos_int)
-        #example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int)
         #print(f"__getitem__ train_item['caption']: {train_item['caption']}")
         #print(f"__getitem__ train_item['pathname']: {train_item['pathname']}")
         #print(f"__getitem__ example['tokens'] pad: {example['tokens']}")
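The new tokenization path builds the 77-token CLIP sequence by hand: the caption is tokenized without special tokens, capped at 75 content tokens, zero-padded on the right, then wrapped with BOS and EOS via F.pad. A minimal standalone sketch of the same scheme (the tokenizer id and helper name are illustrative, not from this commit):

import torch
import torch.nn.functional as F
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
max_token_length = 77  # 75 content tokens + BOS + EOS, as in the diff

def build_padded_tokens(caption: str) -> torch.Tensor:
    # Tokenize without special tokens, capped at 75 content tokens.
    ids = tokenizer(caption,
                    truncation=True,
                    padding=False,
                    add_special_tokens=False,
                    max_length=max_token_length - 2).input_ids
    tokens = torch.tensor(ids)
    # Right-pad with zeros to exactly 75, then wrap with BOS/EOS.
    pad_amt = max_token_length - 2 - len(tokens)
    tokens = F.pad(tokens, pad=(0, pad_amt), mode="constant", value=0)
    tokens = F.pad(tokens, pad=(1, 0), mode="constant", value=int(tokenizer.bos_token_id))
    tokens = F.pad(tokens, pad=(0, 1), mode="constant", value=int(tokenizer.eos_token_id))
    return tokens

print(build_padded_tokens("a photo of a cat").shape)  # torch.Size([77])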
@@ -149,6 +128,9 @@ class EveryDreamBatch(Dataset):
         image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

         example["image"] = image_train_tmp.image
-        example["caption"] = image_train_tmp.caption
+        if random.random() > self.conditional_dropout:
+            example["caption"] = image_train_tmp.caption
+        else:
+            example["caption"] = " "

         return example
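Conditional dropout now also applies at the caption level: with probability cond_dropout the caption is replaced by a blank, which is the standard way to train the unconditional branch used by classifier-free guidance at inference time. The idea in one line (function name is illustrative):

import random

def maybe_drop_caption(caption: str, cond_dropout: float = 0.04) -> str:
    # With probability cond_dropout, train this example as unconditional.
    return caption if random.random() > cond_dropout else " "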
@@ -19,6 +19,7 @@ from torchvision import transforms, utils
 import random
 import math
 import os
+import logging

 _RANDOM_TRIM = 0.04
@@ -48,7 +49,9 @@ class ImageTrainItem():
         save: save the cropped image to disk, for manual inspection of resize/crop
         crop_jitter: randomly shift crop by N pixels when using multiple aspect ratios to improve training quality
         """
+        if not hasattr(self, 'image') or len(self.image) == 0:
-            #print(self.pathname, self.image)
+            try:
-                #if not hasattr(self, 'image'):
                 self.image = PIL.Image.open(self.pathname).convert('RGB')

                 width, height = self.image.size
@@ -96,6 +99,10 @@ class ImageTrainItem():
                 self.image = self.image.resize(self.target_wh, resample=PIL.Image.BICUBIC)

                 self.image = self.flip(self.image)
+            except Exception as e:
+                logging.error(f"Error loading image: {self.pathname}")
+                print(e)
+                exit()

         if type(self.image) is not np.ndarray:
             if save:
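The load path now fails fast with a logged path instead of crashing with a bare stack trace deep inside a DataLoader worker. Sketched standalone (helper name is illustrative; it raises SystemExit rather than calling exit(), same effect without relying on the site module):

import logging
import PIL.Image

def load_rgb(pathname: str) -> PIL.Image.Image:
    try:
        return PIL.Image.open(pathname).convert('RGB')
    except Exception as e:
        # Log which file is broken before stopping; a corrupt image is a
        # data problem the user should fix rather than silently skip.
        logging.error(f"Error loading image: {pathname}")
        print(e)
        raise SystemExit(1)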
@@ -106,7 +113,7 @@ class ImageTrainItem():
         self.image = np.array(self.image).astype(np.uint8)

-        self.image = (self.image / 127.5 - 1.0).astype(np.float32)
+        #self.image = (self.image / 127.5 - 1.0).astype(np.float32)

         #print(self.image.shape)
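The numpy normalization to [-1, 1] is commented out here, presumably because the Dataset now handles it through torchvision (the EveryDreamBatch.__getitem__ hunk above switches to self.image_transforms). The two mappings are equivalent, as this sketch checks (array sizes are illustrative):

import numpy as np
import torch
from torchvision import transforms

img = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)

# numpy version, as in the removed line: uint8 [0, 255] -> float32 [-1, 1]
np_out = (img / 127.5 - 1.0).astype(np.float32)

# torchvision version: ToTensor() scales to [0, 1], Normalize shifts to [-1, 1]
tv = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
tv_out = tv(img)  # shape (3, 64, 64)

assert torch.allclose(tv_out, torch.from_numpy(np_out).permute(2, 0, 1), atol=1e-6)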
train.py (58 lines changed)
@@ -20,6 +20,7 @@ import signal
 import argparse
 import logging
 import time
+import gc

 import torch.nn.functional as torch_functional
 from torch.cuda.amp import autocast
@@ -49,6 +50,8 @@ from data.every_dream import EveryDreamBatch
 from utils.convert_diffusers_to_stable_diffusion import convert as converter
 from utils.gpu import GPU

+import debug
+
 _GRAD_ACCUM_STEPS = 1 # future use...
 _SIGTERM_EXIT_CODE = 130
 _VERY_LARGE_NUMBER = 1e9
@@ -159,6 +162,11 @@ def main(args):
     seed = 555
     set_seed(seed)
     gpu = GPU()
     torch.backends.cudnn.benchmark = False
+    args.clip_skip = max(min(2, args.clip_skip), 0)
+
+    if args.text_encoder_epochs is None or args.text_encoder_epochs < 1:
+        args.text_encoder_epochs = _VERY_LARGE_NUMBER
+
     if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
         logging.info(" no checkpointing specified, defaulting to 20 minutes")
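Two small argument normalizations are added to main(): clip_skip is clamped into the supported [0, 2] range, and an unset text_encoder_epochs becomes _VERY_LARGE_NUMBER so the freeze check later in the epoch loop simply never triggers. A sketch of both idioms (function name is illustrative):

_VERY_LARGE_NUMBER = 1e9

def normalize_args(clip_skip, text_encoder_epochs):
    # Clamp clip_skip into the supported range [0, 2].
    clip_skip = max(min(2, clip_skip), 0)
    # Treat "unset or < 1" as "never freeze": the epoch counter will
    # never reach 1e9, so the freeze branch never fires.
    if text_encoder_epochs is None or text_encoder_epochs < 1:
        text_encoder_epochs = _VERY_LARGE_NUMBER
    return clip_skip, text_encoder_epochs

print(normalize_args(5, 0))  # (2, 1000000000.0)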
@@ -279,7 +287,7 @@ def main(args):
     i = 0
     for prompt in prompts:
         if prompt is None or len(prompt) < 2:
-            logging.warning("empty prompt in sample prompts, check your prompts file")
+            #logging.warning("empty prompt in sample prompts, check your prompts file")
             continue
         images = []
         for cfg in [7.0, 4.0, 1.01]:
@@ -301,14 +309,16 @@ def main(args):

                clean_prompt = clean_filename(prompt)

-                result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.png")
+                result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
+                with open(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.txt", "w") as f:
+                    f.write(prompt)

                tfimage = transforms.ToTensor()(result)
                if random_captions:
                    log_writer.add_image(tag=f"sample_{i}", img_tensor=tfimage, global_step=gs)
                    i += 1
                else:
-                    log_writer.add_image(tag=f"sample_{clean_prompt[:150]}", img_tensor=tfimage, global_step=gs)
+                    log_writer.add_image(tag=f"sample_{i}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=gs)
+                    i += 1

                del result
                del tfimage
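Sample images switch from PNG to quality-95 JPEG, with the prompt written to a sidecar .txt file; for photographic samples this typically shrinks files severalfold at negligible visual cost. A minimal sketch (paths and the placeholder image are illustrative):

from PIL import Image

result = Image.new("RGB", (512, 512), "gray")  # stand-in for a generated sample
result.save("sample.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open("sample.txt", "w") as f:
    f.write("the prompt used for this sample")  # keep the prompt alongside the image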
@@ -406,7 +416,7 @@ def main(args):
        for line in f:
            sample_prompts.append(line.strip())

-    log_folder = os.path.join("logs", f"{args.project_name}{log_time}")
+    log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}")
    logging.info(f"Logging to {log_folder}")

    if False: #args.wandb is not None and args.wandb: # not yet supported
@@ -428,7 +438,7 @@ def main(args):

    log_args(log_writer, args)

-    args.clip_skip = max(min(2, args.clip_skip), 0)
+

    """
    Train the model
@@ -547,17 +557,20 @@ def main(args):
                save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
                __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)

+        if epoch == args.text_encoder_epochs:
+            logging.info(f" Freezing text_encoder at epoch: {epoch}")
+            text_encoder.requires_grad_(False)
+            text_encoder.eval()
+
        epoch_start_time = time.time()
        steps_pbar.reset()
        images_per_sec_epoch = []

        for step, batch in enumerate(train_dataloader):
+            if args.debug:
+                debug.debug_batch(batch)
            step_start_time = time.time()

-            if global_step > 0 and global_step > args.text_encoder_steps == 0:
-                text_encoder.requires_grad_(False)
-                text_encoder.eval()
-
            with torch.no_grad():
                with autocast():
                    pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
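Text-encoder freezing moves from a per-step check to a per-epoch one. The removed condition was a Python chained comparison, global_step > args.text_encoder_steps == 0, which reads as "global_step > args.text_encoder_steps and args.text_encoder_steps == 0", so it froze the encoder immediately whenever the option was left at its default of 0. The freeze idiom itself is two calls, sketched here on a stand-in module:

import torch
import torch.nn as nn

def freeze(module: nn.Module) -> None:
    module.requires_grad_(False)  # exclude parameters from autograd
    module.eval()                 # switch dropout etc. to inference behavior

text_encoder = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.1))  # stand-in
freeze(text_encoder)
assert not any(p.requires_grad for p in text_encoder.parameters())
assert not text_encoder.training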
@@ -576,10 +589,17 @@ def main(args):
            cuda_caption = batch["tokens"].to(text_encoder.device)

            with autocast(enabled=args.amp):
-                encoder_hidden_states = text_encoder(cuda_caption)
+                #encoder_hidden_states = text_encoder(cuda_caption)
+                encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)

-            # if clip_skip > 0: #TODO
-            #     encoder_hidden_states = encoder_hidden_states['last_hidden_state'][-clip_skip]
+            #print("encoder_hidden_states.keys()", encoder_hidden_states.keys())
+            #print("encoder_hidden_states.hidden_states.shape", encoder_hidden_states.hidden_states[0].shape)
+            #print("encoder_hidden_states.last_hidden_state.shape", encoder_hidden_states.last_hidden_state[0].shape)

+            if args.clip_skip > 0: # TODO
+                encoder_hidden_states = encoder_hidden_states.hidden_states[-args.clip_skip]
+            else:
+                encoder_hidden_states = encoder_hidden_states.last_hidden_state

            noisy_latents = scheduler.add_noise(latents, noise, timesteps)
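With output_hidden_states=True the text encoder returns activations from every layer, and clip_skip indexes from the end: hidden_states[-2] is the penultimate layer popularized as "clip skip 2". A runnable sketch against the Hugging Face CLIP-L encoder (the SD 1.x text encoder; the model id here is an assumption, the diff never names it):

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("a photo of a cat", padding="max_length",
                   max_length=tokenizer.model_max_length,
                   return_tensors="pt").input_ids

clip_skip = 2
with torch.no_grad():
    out = text_encoder(tokens, output_hidden_states=True)
if clip_skip > 0:
    hidden = out.hidden_states[-clip_skip]  # penultimate layer for clip_skip=2
else:
    hidden = out.last_hidden_state
print(hidden.shape)  # torch.Size([1, 77, 768])

One caveat: in HF CLIP, last_hidden_state has the final layer norm applied while hidden_states entries do not, so some implementations re-apply text_encoder.text_model.final_layer_norm to the selected layer; the code above, like the diff, uses the raw layer output.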
@@ -592,7 +612,7 @@ def main(args):
            #del noise, latents

            with autocast(): # xformers requires autocast
-                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states.last_hidden_state).sample
+                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

            #with autocast(enabled=args.amp):
            loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
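The UNet now consumes the already-selected tensor directly, since after the clip_skip branch encoder_hidden_states is a tensor rather than a model output object. The loss itself is plain noise-prediction MSE; a toy sketch with stand-in tensors (shapes are illustrative, matching SD's 4-channel latents):

import torch
import torch.nn.functional as torch_functional

model_pred = torch.randn(2, 4, 64, 64)  # stand-in for unet(...).sample
target = torch.randn(2, 4, 64, 64)      # stand-in for the sampled noise
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
print(loss.item())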
@@ -607,8 +627,8 @@ def main(args):
            loss.backward()
            #scaler.unscale_(optimizer)
            #if args.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=2)
-            torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=2)
+            torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=1)
+            torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=1)
            #scaler.step(optimizer)
            #scaler.update()
            optimizer.step()
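Gradient clipping is tightened from max_norm=2 to 1 for both models. clip_grad_norm_ rescales a parameter group's combined gradient so its global L2 norm does not exceed max_norm, and it must run between backward() and step(); a sketch on a stand-in model:

import torch
import torch.nn as nn

model = nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

loss = model(torch.randn(4, 10)).pow(2).mean()
loss.backward()
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"pre-clip grad norm: {total_norm.item():.3f}")  # returns the norm before clipping
optimizer.step()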
@@ -654,6 +674,7 @@ def main(args):

                del pipe
                torch.cuda.empty_cache()
+                gc.collect()

                min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
@@ -714,7 +735,7 @@ if __name__ == "__main__":
    argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
    argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
    argparser.add_argument("--grad_accum", type=int, default=1, help="NONFUNCTIONING. Gradient accumulation factor (def: 1), (ex, 2)")
-    argparser.add_argument("--clip_skip", type=int, default=0, help="NONFUNCTIONING. Train using penultimate layers (def: 0)", choices=[0, 1, 2])
+    argparser.add_argument("--clip_skip", type=int, default=2, help="Train using penultimate layer (def: 2)", choices=[0, 1, 2])
    argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
    argparser.add_argument("--mixed_precision", default="no", help="NONFUNCTIONING. precision, (default: NO for fp32)", choices=["NO", "fp16", "bf16"])
    argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
@@ -724,7 +745,8 @@ if __name__ == "__main__":
    argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
    argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
    argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
-    argparser.add_argument("--text_encoder_steps", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
+    argparser.add_argument("--text_encoder_epochs", type=int, default=0, help="disable text encoder training after N epochs (def: disabled)")
+    argparser.add_argument("--debug", action="store_true", default=False)
    args = argparser.parse_args()

    main(args)