few sanity checks and remove keyboard

This commit is contained in:
Victor Hall 2023-01-20 09:42:24 -05:00
parent e09249b699
commit 1c2708dc63
6 changed files with 65 additions and 33 deletions

View File

@ -47,6 +47,7 @@ class DataLoaderMultiAspect():
self.log_folder = log_folder
self.seed = seed
self.batch_size = batch_size
self.has_scanned = False
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
@ -203,6 +204,9 @@ class DataLoaderMultiAspect():
"""
decorated_image_train_items = []
if not self.has_scanned:
undersized_images = []
for pathname in tqdm.tqdm(image_paths):
caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0]
caption = DataLoaderMultiAspect.__split_caption_into_tags(caption_from_filename)
@ -225,15 +229,30 @@ class DataLoaderMultiAspect():
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
if not self.has_scanned:
if width * height < target_wh[0] * target_wh[1]:
undersized_images.append(f" *** {pathname} with size: {width},{height} is smaller than target size: {target_wh}, consider using larger images")
image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
decorated_image_train_items.append(image_train_item)
except Exception as e:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {e}")
pass
if not self.has_scanned:
self.has_scanned = True
if len(undersized_images) > 0:
underized_log_path = os.path.join(self.log_folder, "undersized_images.txt")
logging.warning(f"{Fore.LIGHTRED_EX} ** Some images are smaller than the target size, consider using larger images{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTRED_EX} ** Check {underized_log_path} for more information.{Style.RESET_ALL}")
with open(underized_log_path, "w") as undersized_images_file:
undersized_images_file.write(f" The following images are smaller than the target size, consider removing or sourcing a larger copy:")
for undersized_image in undersized_images:
undersized_images_file.write(undersized_image)
return decorated_image_train_items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:

View File

@ -108,15 +108,17 @@ Some experimentation shows if you already have batch size in the 6-8 range than
## Gradient checkpointing
While traditionally used to reduce VRAM for smaller GPUs, gradient checkpointing can offer a higher batch size and/or higher resolution within whatever VRAM you have, so it may be useful even on a 24GB+ GPU.
This is mostly useful to reduce VRAM for smaller GPUs, and together with AdamW 8 bit and AMP mode can enable <12GB GPU training.
Gradient checkpointing can also offer a higher batch size and/or higher resolution within whatever VRAM you have, so it may be useful even on a 24GB+ GPU if you specifically want to run a very large batch size. The other option is using gradient accumulation instead.
--gradient_checkpointing ^
This drastically reduces VRAM (by many GB) and will allow quite a larger batch size or resolution, for example, 13-14 instead of 7-8 on a 24GB card using 512 training resolution.
While gradient checkpointing reduces performance, the ability to run a higher batch size brings performance back fairly close to without it.
While gradient checkpointing reduces performance, the ability to run a higher batch size brings performance back fairly close to without it. My personal tests show a 25% performance hit simply turning on gradient checkpointing on a 3090 (batch 7, 512), but almost all of that is made up by the ability to use a larger batch size (up to 14). You may NOT want to use a batch size as large as 13-14, or you may find you need to tweak learning rate all over again to find the right balance.
You may NOT want to use a batch size as large as 13-14, or you may find you need to tweak learning rate all over again to find the right balance. Generally I would not turn it on for a 24GB GPU training at <640 resolution.
This probably IS a good idea for training at higher resolutions. Balancing this toggle, resolution, batch_size, and grad_accum will take some experimentation, but you might try using this with 768+ resolutions, grad_accum 3-4, and then as high of a batch size as you can get to work without crashing, while adjusting LR with respect to your (batch_size * grad_accum) value.
This probably IS a good idea for training at higher resolutions and allows >768 training on 24GB GPUs. Balancing this toggle, resolution, and batch_size will take a few quick experiments to see what you can run safely.
## Flip_p

View File

@ -16,6 +16,17 @@ You may wish to consider adding "sd1" or "sd2v" or similar to remember what the
--project_name "jets_sd21768v" ^
## Stuff you probably want on
--amp
Enables automatic mixed precision. Greatly improved training speed and can help a bit with VRAM use. [Torch](https://pytorch.org/docs/stable/amp.html) will automatically use FP16 precision for specific model components where FP16 is sufficient precision, and FP32 otherwise. This also enables xformers to work with the SD1.x attention head schema.
--useadam8bit
Uses [Tim Dettmer's reduced precision AdamW 8 Bit optimizer](https://github.com/TimDettmers/bitsandbytes). This seems to have no noticeable impact on quality but is considerable faster and more VRAM efficient. See more below in AdamW vs AdamW 8bit.
## Epochs
EveryDream 2.0 has done away with repeats and instead you should set your max_epochs. Changing epochs has the same effect as changing repeats in DreamBooth or EveryDream1. For example, if you had 50 repeats and 5 epochs, you would now set max_epochs to 250 (50x5=250). This is a bit more intuitive as there is no more double meaning for epochs and repeats.
@ -28,6 +39,16 @@ With more training data for your subjects and concepts, you can slowly scale thi
With less training data, this value should be higher, because more repetition on the images is needed to learn.
## Resolution
The resolution for training. All buckets for multiaspect will be based on the total pixel count of your resolution squared.
--resolution 768
Current supported resolutions can be printed by running the trainer without any arugments.
python train.py
## Save interval for checkpoints
While EveryDream 1.0 saved a checkpoint every epoch, this is no longer the case as it would produce too many files as "repeats" are removed in favor of just using epochs instead. To balance the fact EveryDream users are sometimes training small datasets and sometimes huge datasets, you can now set the interval at which checkpoints are saved. The default is 30 minutes, but you can change it to whatever you want.
@ -76,16 +97,6 @@ At this time, ED2.0 supports constant or cosine scheduler.
The constant scheduler is the default and keeps your LR set to the value you set in the command line. That's really it for constant! I recommend sticking with it until you are comfortable with general training. More info in the [Advanced Tweaking](ATWEAKING.md) document.
## AdamW vs AdamW 8bit
The AdamW optimizer is the default and what was used by EveryDream 1.0. It's a good optimizer for Stable Diffusion and appears to be what was used to train SD itself.
AdamW 8bit is quite a bit faster and uses less VRAM while still having the same basic behavior. I currently **recommend** using it for most cases as it seems worth a potential slight reduction in quality for a *significant speed boost and lower VRAM cost*.
--useadam8bit ^
This may become a default in the future, and replaced with an option to use standard AdamW instead. For now, it's an option, *but I recommend always using it.*
## Sampling
You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt. Or you can point to another file with --sample_prompts.

BIN
doc/runpodinstance.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 37 KiB

View File

@ -47,8 +47,6 @@ from accelerate.utils import set_seed
import wandb
from torch.utils.tensorboard import SummaryWriter
import keyboard
from data.every_dream import EveryDreamBatch
from utils.convert_diff_to_ckpt import convert as converter
from utils.gpu import GPU
@ -88,7 +86,7 @@ def convert_to_hf(ckpt_path):
import utils.convert_original_stable_diffusion_to_diffusers as convert
convert.convert(ckpt_path, f"ckpt_cache/{ckpt_path}")
except:
logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
logging.info("Please manually convert the checkpoint to Diffusers format (one time setup), see readme.")
exit()
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
@ -350,16 +348,17 @@ def main(args):
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml:
yaml_save_path = f"{os.path.basename(save_path)}.yaml"
if yaml_name:
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml, yaml_save_path)
shutil.copyfile(yaml_name, yaml_save_path)
# optimizer_path = os.path.join(save_path, "optimizer.pt")
# if self.save_optimizer_flag:
@ -474,7 +473,6 @@ def main(args):
sample_scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
logging.info(f" Inferred yaml: {yaml}, attention head type: {'sd1' if is_sd1attn else 'sd2'}")
except:
logging.ERROR(" * Failed to load checkpoint *")
@ -482,13 +480,14 @@ def main(args):
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()
if not args.disable_xformers and (args.amp and is_sd1attn) or (not is_sd1attn):
try:
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
except Exception as ex:
logging.warning("failed to load xformers, continuing without it")
pass
if not args.disable_xformers:
if (args.amp and is_sd1attn) or (not is_sd1attn):
try:
unet.enable_xformers_memory_efficient_attention()
logging.info("Enabled xformers")
except Exception as ex:
logging.warning("failed to load xformers, continuing without it")
pass
else:
logging.info("xformers not available or disabled")
@ -704,6 +703,8 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
loss_log_step = []
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
try:
for epoch in range(args.max_epochs):
@ -754,16 +755,13 @@ def main(args):
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
#del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
del target, model_pred
#if args.amp:
scaler.scale(loss).backward()
#else:
# loss.backward()
if args.clip_grad_norm is not None:
if not args.disable_unet_training:
@ -819,7 +817,7 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
torch.cuda.empty_cache()
if (not args.notebook and keyboard.is_pressed("ctrl+alt+page up")) or ((global_step + 1) % args.sample_steps == 0):
if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device)

View File

@ -46,4 +46,6 @@ def patch_unet(ckpt_path):
else:
raise ValueError(f"Unknown model format for: {prediction_type} and attention_head_dim {unet_cfg['attention_head_dim']}")
logging.info(f"Inferred yaml: {yaml}, attn: {'sd1' if is_sd1attn else 'sd2'}, prediction_type: {prediction_type}")
return is_sd1attn, yaml