various tweaks and bugfixes over holidays

This commit is contained in:
Victor Hall 2022-12-27 14:25:32 -05:00
parent 1bfe6e97fb
commit 4c53f2d55c
8 changed files with 227 additions and 154 deletions

View File

@ -15,10 +15,12 @@ limitations under the License.
"""
import os
import logging
from PIL import Image
import random
from data.image_train_item import ImageTrainItem
import data.aspects as aspects
from colorama import Fore, Style
class DataLoaderMultiAspect():
"""
@ -28,14 +30,15 @@ class DataLoaderMultiAspect():
batch_size: number of images per batch
flip_p: probability of flipping image horizontally (i.e. 0-0.5)
"""
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512):
def __init__(self, data_root, seed=555, debug_level=0, batch_size=1, flip_p=0.0, resolution=512, log_folder=None):
self.image_paths = []
self.debug_level = debug_level
self.flip_p = flip_p
self.log_folder = log_folder
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
print(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
print(" Preloading images...")
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
logging.info(" Preloading images...")
self.__recurse_data_root(self=self, recurse_root=data_root)
random.Random(seed).shuffle(self.image_paths)
@ -54,7 +57,7 @@ class DataLoaderMultiAspect():
with open(file_path, encoding='utf-8', mode='r') as caption_file:
caption = caption_file.read()
except:
print(f" *** Error reading {file_path} to get caption, falling back to filename")
logging.error(f" *** Error reading {file_path} to get caption, falling back to filename")
caption = fallback_caption
pass
return caption
@ -78,20 +81,24 @@ class DataLoaderMultiAspect():
else:
caption = caption_from_filename
image = Image.open(pathname)
width, height = image.size
image_aspect = width / height
try:
image = Image.open(pathname)
width, height = image.size
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
image_train_item = ImageTrainItem(image=None, caption=caption, target_wh=target_wh, pathname=pathname, flip_p=flip_p)
decorated_image_train_items.append(image_train_item)
decorated_image_train_items.append(image_train_item)
except Exception as e:
logging.error(f"{Fore.LIGHTRED_EX} *** Error opening {Fore.LIGHTYELLOW_EX}{pathname}{Fore.LIGHTRED_EX} to get metadata. File may be corrupt and will be skipped.{Style.RESET_ALL}")
logging.error(f" *** exception: {e}")
pass
return decorated_image_train_items
@staticmethod
def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0):
def __bucketize_images(self, prepared_train_data: list, batch_size=1, debug_level=0):
"""
Put images into buckets based on aspect ratio with batch_size*n images per bucket, discards remainder
"""
@ -105,16 +112,21 @@ class DataLoaderMultiAspect():
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
print(f" ** Number of buckets used: {len(buckets)}")
logging.info(f" ** Number of buckets used: {len(buckets)}")
if len(buckets) > 1:
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
with open(os.path.join(self.log_folder, "bucket_drops.txt"), "a") as f:
f.write(f"{bucket} {truncate_count} dropped files:\n")
for item in buckets[bucket][-truncate_count:]:
f.write(f"- {item.pathname}\n")
current_bucket_size = len(buckets[bucket])
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
if debug_level > 0:
print(f" ** Bucket {bucket} with {current_bucket_size} will drop {truncate_count} images due to batch size {batch_size}")
logging.warning(f" ** Bucket {bucket} with {current_bucket_size} will drop {truncate_count} images due to batch size {batch_size}")
# flatten the buckets
image_caption_pairs = []
@ -131,9 +143,9 @@ class DataLoaderMultiAspect():
try:
with open(multiply_path, encoding='utf-8', mode='r') as f:
multiply = int(float(f.read().strip()))
print(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
logging.info(f" * DLMA multiply.txt in {recurse_root} set to {multiply}")
except:
print(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
logging.error(f" *** Error reading multiply.txt in {recurse_root}, defaulting to 1")
pass
for f in os.listdir(recurse_root):

View File

@ -13,6 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import torch
from torch.utils.data import Dataset
from data.data_loader import DataLoaderMultiAspect as dlma
@ -45,6 +46,7 @@ class EveryDreamBatch(Dataset):
crop_jitter=20,
seed=555,
tokenizer=None,
log_folder=None,
):
self.data_root = data_root
self.batch_size = batch_size
@ -53,6 +55,7 @@ class EveryDreamBatch(Dataset):
self.crop_jitter = crop_jitter
self.unloaded_to_idx = 0
self.tokenizer = tokenizer
self.log_folder = log_folder
#print(f"tokenizer: {tokenizer}")
self.max_token_length = self.tokenizer.model_max_length
@ -60,14 +63,18 @@ class EveryDreamBatch(Dataset):
seed = random.randint(0, 99999)
if not dls.shared_dataloader:
print(" * Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root, seed=seed, debug_level=debug_level, batch_size=self.batch_size, flip_p=flip_p, resolution=resolution)
logging.info(" * Creating new dataloader singleton")
dls.shared_dataloader = dlma(data_root=data_root,
seed=seed,
debug_level=debug_level,
batch_size=self.batch_size,
flip_p=flip_p,
resolution=resolution,
log_folder=self.log_folder,
)
self.image_train_items = dls.shared_dataloader.get_all_images()
# for iti in self.image_train_items:
# print(f"iti caption:{iti.caption}")
# exit()
self.num_images = len(self.image_train_items)
self._length = self.num_images
@ -79,9 +86,7 @@ class EveryDreamBatch(Dataset):
]
)
print()
print(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}, length w/repeats: {self._length}")
print()
logging.info(f" ** Trainer Set: {self._length / batch_size:.0f}, num_images: {self.num_images}, batch_size: {self.batch_size}")
def __len__(self):
return self._length
@ -98,25 +103,8 @@ class EveryDreamBatch(Dataset):
padding="max_length",
max_length=self.tokenizer.model_max_length,
).input_ids
#print(example["tokens"])
example["tokens"] = torch.tensor(example["tokens"])
# else:
# example["tokens"] = torch.zeros(75, dtype=torch.int)
#print(f"bos: {self.tokenizer.bos_token_id}{self.tokenizer.eos_token_id}")
#print(f"example['tokens']: {example['tokens']}")
#pad_amt = self.max_token_length-2 - len(example["tokens"])
#example['tokens']= F.pad(example['tokens'],pad=(0,pad_amt),mode='constant',value=0)
#example['tokens']= F.pad(example['tokens'],pad=(1,0),mode='constant',value=int(self.tokenizer.bos_token_id))
#eos_int = int(self.tokenizer.eos_token_id)
#eos_int = int(0)
#example['tokens']= F.pad(example['tokens'],pad=(0,1),mode='constant',value=eos_int)
#print(f"__getitem__ train_item['caption']: {train_item['caption']}")
#print(f"__getitem__ train_item['pathname']: {train_item['pathname']}")
#print(f"__getitem__ example['tokens'] pad: {example['tokens']}")
example["caption"] = train_item["caption"] # for sampling if needed
#print(f"len tokens: {len(example['tokens'])} cap: {example['caption']}")
return example

View File

@ -100,8 +100,8 @@ class ImageTrainItem():
self.image = self.flip(self.image)
except Exception as e:
logging.error(f"Error loading image: {self.pathname}")
print(e)
logging.error(f"Fatal Error loading image: {self.pathname}:")
logging.error(e)
exit()
if type(self.image) is not np.ndarray:

View File

@ -8,7 +8,7 @@ You can train resolutions from 512 to 1024 in 64 pixel increments. General resu
For instance, if training from the base 1.5 model, you can try trying at 576, 640, or 704.
If you are training on a base model that is 768, such as SD 2.1 768-v, you should also probably use 768 as a base number and adjust from there.
If you are training on a base model that is 768, such as "SD 2.1 768-v", you should also probably use 768 as a base number and adjust from there.
## Log and ckpt save folders
@ -24,6 +24,16 @@ By default the CKPT format copies of ckpts that are peroidically saved are saved
--ckpt_dir "r:\webui\models\stable-diffusion"
This is useful if you want to dump the CKPT files directly to your webui/inference program model folder.
## Clip skip
Aka "penultimate layer", this takes the output from the text encoder not from its last output layer, but layers before.
--clip_skip 2 ^
A value of "2" is the canonical form of "penultimate layer" useed by various webuis, but 1 an 3 are accepted as well if you wish to experiment. Default is "0" which takes the "last hidden layer" or standard output of the text encoder as Stable Diffusion was originally designed. Training with this setting may necessititate or work better when also using the same setting in your webui/inference program.
## Conditional dropout
Conditional dropout means the prompt or caption on the training image is dropped, and the caption is "blank". The theory is this can help with unconditional guidance, per the original paper and authors of Latent Diffusion and Stable Diffusion.
@ -61,4 +71,35 @@ Example:
The above example with combine the loss from 2 batches before applying updates. This *may* be a good idea for higher resolution training that requires smaller batch size but mega batch sizes are also not the be-all-end all.
Some experimentation shows if you already have batch size in the 6-8 range than grad accumulation of more than 2 just reduces quality, but you can experiment.
Some experimentation shows if you already have batch size in the 6-8 range than grad accumulation of more than 2 just reduces quality, but you can experiment.
## Flip_p
If you wish for your training images to be randomly flipped horizontally, use this to flip the images 50% of the time:
--flip_p 0.5 ^
This is useful for styles or other training that is not symmetrical. It is not suggested for training specific human faces as it may wash out facial features. It is also not suggested if any of your captions included directions like "left" or "right". Default is 0.0 (no flipping)
# Stuff you probably don't need to mess with
## log_step
Change how often log items are written. Default is 25 and probably good for most situations. This does not affect how often samples or ckpts are saved, just log scalar items.
--log_step 50 ^
## scale_lr
Attempts to automatically scale your learning rate up or down base on changes to batch size and gradient accumulation.
--scale_lr ^
This multiplies your ```--lr``` setting by ```sqrt of (batch_size times grad_accum)```. This can be useful if you're tweaking batch size and grad accum and want to keep your LR to a sane value.
## clip_grad_norm
Clips the gradient normals to a maximum value. This is an experimental feature, you can read online about gradient clipping. Default is None (no clipping). This is typically used for gradient explosion problems, but might be a fun thing to experiment with.
--clip_grad_norm 1.0 ^

View File

@ -42,7 +42,7 @@ Training from SD2 512 base model, 18 epochs, 4 batch size, 1.2e-6 learning rate,
--ckpt_every_n_minutes 30 ^
--useadam8bit
Training from the "SD21" model on the "jets" dataset on another drive, for 50 epochs, 6 batch size, 1.5e-6 learning rate, cosine scheduler that will decay in 1500 steps, generate samples evern 100 steps, 30 minute checkpoint interval, adam8bit:
Training from the "SD21" model on the "jets" dataset on another drive, for 50 epochs, 6 batch size, 1.5e-6 learning rate, cosine scheduler that will decay in 1500 steps, generate samples evern 100 steps, save a checkpoint every 20 epochs, and use AdamW 8bit optimizer:
python train.py --resume_ckpt "SD21" ^
--data_root "R:\everydream-trainer\training_samples\mega\gt\objects\jets" ^
@ -54,25 +54,20 @@ Training from the "SD21" model on the "jets" dataset on another drive, for 50 ep
--batch_size 6 ^
--sample_steps 100 ^
--lr 1.5e-6 ^
--ckpt_every_n_minutes 30 ^
--save_every_n_epochs 20 ^
--useadam8bit
Copy paste the above to your command line and press enter.
Make sure the last line does not have ^ but all other lines do
Scheduler example, note warmup and decay dont work with constant (default), warmup is set automatically to 5% of decay if not set
--lr_scheduler cosine
--lr_warmup_steps 100
--lr_decay_steps 2500
Warmup and decay only count for some schedulers! Constant is not one of them.
Currently "constant" and "cosine" are recommended. I'll add support to others upon request.
Make sure the last line does not have ^ but all other lines do. If you want you can put the command all on one line and not use the ^ carats instead.
## How to resume
Point your resume_ckpt to the path in logs like so:
```--resume_ckpt "R:\everydream2trainer\logs\myproj20221213-161620\ckpts\myproj-ep22-gs01099" ^```
Or use relative pathing:
```--resume_ckpt "logs\myproj20221213-161620\ckpts\myproj-ep22-gs01099" ^```
You should point to the folder in the logs per above if you want to resume rather than running a conversion back on a 2.0GB or 2.5GB pruned file if possible.

View File

@ -37,7 +37,7 @@ If you wish instead to save every certain number of epochs, you can set the minu
## Learning Rate
The learning rate affects how much "training" is done on the model. It is a very careful balance to select a value that will learn your data, but not overfit it. If you set the LR too high, the model will "fry" or could "overtrain" and become too rigid, only learning to exactly mimick your training data images and will not be able to generalize to new data or be "stylable". If you set the LR too low, you may take longer to train, or it may have difficulty learning the concepts at all. Usually sane values are 1e-6 to 3e-6
The learning rate affects how much "training" is done on the model. It is a very careful balance to select a value that will learn your data, but not overfit it. If you set the LR too high, the model will "fry" or could "overtrain" and become too rigid, only learning to exactly mimick your training data images and will not be able to generalize to new data or be "stylable". If you set the LR too low, you may take longer to train, or it may have difficulty learning the concepts at all. Usually sane values are 1e-6 to 3e-6.
## Batch Size
@ -62,16 +62,16 @@ The constant scheduler is the default and keeps your LR set to the value you set
The AdamW optimizer is the default and what was used by EveryDream 1.0. It's a good optimizer for stable diffusion and appears to be what was used to train SD itself.
AdamW 8bit is quite a bit faster and uses less VRAM. I currently **recommend** using it for most cases as it seems worth a potential reduction in quality.
AdamW 8bit is quite a bit faster and uses less VRAM. I currently **recommend** using it for most cases as it seems worth a potential reduction in quality for a significant speed boost and lower VRAM cost.
--useadam8bit ^
## Sampling
You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt. Or you can point to another file with --sample_prompts_file.
You can set your own sample prompts by adding them, one line at a time, to sample_prompts.txt. Or you can point to another file with --sample_prompts.
--sample_prompts "project_XYZ_test_prompts.txt" ^
Keep in mind a lot of prompts will take longer to generate. You may also want to adjust sample_steps to a different value to get samples left often. This is probably a good idea when training a larger dataset that you know will take longer to train, where more frequent samples will not help you.
Keep in mind a longer list of prompts will take longer to generate. You may also want to adjust sample_steps to a different value to get samples left often. This is probably a good idea when training a larger dataset that you know will take longer to train, where more frequent samples will not help you.
--sample_steps 500 ^

183
train.py
View File

@ -21,6 +21,7 @@ import argparse
import logging
import time
import gc
import random
import torch.nn.functional as torch_functional
from torch.cuda.amp import autocast
@ -63,6 +64,7 @@ def clean_filename(filename):
def convert_to_hf(ckpt_path):
hf_cache = os.path.join("ckpt_cache", os.path.basename(ckpt_path))
from utils.patch_unet import patch_unet
if os.path.isfile(ckpt_path):
if not os.path.exists(hf_cache):
@ -74,8 +76,13 @@ def convert_to_hf(ckpt_path):
except:
logging.info("Please manually convert the checkpoint to Diffusers format, see readme.")
exit()
else:
logging.info(f"Found cached checkpoint at {hf_cache}")
patch_unet(hf_cache)
return hf_cache
elif os.path.isdir(hf_cache):
patch_unet(hf_cache)
return hf_cache
else:
return ckpt_path
@ -151,24 +158,20 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
def main(args):
"""
Main entry point
"""
log_time = setup_local_logger(args)
seed = 555
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
set_seed(seed)
gpu = GPU()
torch.backends.cudnn.benchmark = False
args.clip_skip = max(min(2, args.clip_skip), 0)
if args.text_encoder_epochs is None or args.text_encoder_epochs < 1:
args.text_encoder_epochs = _VERY_LARGE_NUMBER
args.clip_skip = max(min(3, args.clip_skip), 0)
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(" no checkpointing specified, defaulting to 20 minutes")
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20
if args.ckpt_every_n_minutes is None or args.ckpt_every_n_minutes < 1:
@ -178,12 +181,26 @@ def main(args):
args.save_every_n_epochs = _VERY_LARGE_NUMBER
if args.save_every_n_epochs < _VERY_LARGE_NUMBER and args.ckpt_every_n_minutes < _VERY_LARGE_NUMBER:
logging.warning(f"{Fore.YELLOW}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
logging.warning(f"{Fore.YELLOW}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTYELLOW_EX}Both save_every_n_epochs and ckpt_every_n_minutes are set, this will potentially spam a lot of checkpoints{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTYELLOW_EX}save_every_n_epochs: {args.save_every_n_epochs}, ckpt_every_n_minutes: {args.ckpt_every_n_minutes}{Style.RESET_ALL}")
if args.cond_dropout > 0.26:
logging.warning(f"{Fore.YELLOW}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
logging.warning(f"{Fore.LIGHTYELLOW_EX}cond_dropout is set fairly high: {args.cond_dropout}, make sure this was intended{Style.RESET_ALL}")
total_batch_size = args.batch_size * args.grad_accum
if args.grad_accum > 1:
logging.info(f"{Fore.CYAN} Batch size: {args.batch_size}, grad accum: {args.grad_accum}, 'effective' batch size: {args.batch_size * args.grad_accum}{Style.RESET_ALL}")
if args.scale_lr is not None and args.scale_lr:
tmp_lr = args.lr
args.lr = args.lr * (total_batch_size**0.5)
logging.info(f"{Fore.CYAN} * Scaling learning rate {tmp_lr} by {total_batch_size**0.5}, new value: {args.lr}{Style.RESET_ALL}")
log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}")
logging.info(f"Logging to {log_folder}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, save_ckpt_dir):
@ -239,15 +256,14 @@ def main(args):
try:
pipe.enable_xformers_memory_efficient_attention()
except Exception as ex:
print("failed to load xformers, continuing without it")
logging.warning("failed to load xformers, continuing without it")
pass
return pipe
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int):
def __generate_sample(pipe: StableDiffusionPipeline, prompt : str, cfg: float, resolution: int, gen):
"""
generates a single sample at a given cfg scale and saves it to disk
"""
gen = torch.Generator(device="cuda").manual_seed(555)
"""
with torch.no_grad(), autocast():
image = pipe(prompt,
num_inference_steps=30,
@ -282,6 +298,8 @@ def main(args):
generates samples at different cfg scales and saves them to disk
"""
logging.info(f"Generating samples gs:{gs}, for {prompts}")
seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
gen = torch.Generator(device="cuda").manual_seed(seed)
i = 0
for prompt in prompts:
@ -290,7 +308,7 @@ def main(args):
continue
images = []
for cfg in [7.0, 4.0, 1.01]:
image = __generate_sample(pipe, prompt, cfg, resolution=resolution)
image = __generate_sample(pipe, prompt, cfg, resolution=resolution, gen=gen)
images.append(image)
width = 0
@ -311,6 +329,7 @@ def main(args):
result.save(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{log_folder}/samples/gs{gs:05}-{clean_prompt[:100]}.txt", "w") as f:
f.write(prompt)
f.write(f"\n seed: {seed}")
tfimage = transforms.ToTensor()(result)
if random_captions:
@ -325,9 +344,10 @@ def main(args):
try:
hf_ckpt_path = convert_to_hf(args.resume_ckpt)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder", torch_dtype=torch.float32 if not args.amp else torch.float16)
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae", torch_dtype=torch.float32 if not args.amp else torch.float16)
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet", torch_dtype=torch.float32)
text_encoder = CLIPTextModel.from_pretrained(hf_ckpt_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(hf_ckpt_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(hf_ckpt_path, subfolder="unet")
#unet.upcast_attention(True)
scheduler = DDIMScheduler.from_pretrained(hf_ckpt_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(hf_ckpt_path, subfolder="tokenizer", use_fast=False)
except:
@ -335,8 +355,10 @@ def main(args):
if is_xformers_available():
try:
#pass
unet.enable_xformers_memory_efficient_attention()
logging.info(" Enabled memory efficient attention (xformers)")
#unet.set_attention_slice(4)
#logging.info(" Enabled memory efficient attention")
except Exception as e:
logging.warning(
"Could not enable memory efficient attention. Make sure xformers is installed"
@ -344,11 +366,11 @@ def main(args):
)
default_lr = 2e-6
lr = args.lr if args.lr is not None else default_lr
curr_lr = args.lr if args.lr is not None else default_lr
vae = vae.to(torch.device("cuda"), dtype=torch.float32)
unet = unet.to(torch.device("cuda"))
text_encoder = text_encoder.to(torch.device("cuda"))
vae = vae.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
unet = unet.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
text_encoder = text_encoder.to(torch.device("cuda"), dtype=torch.float32 if not args.amp else torch.float16)
if args.disable_textenc_training:
logging.info(f"{Fore.CYAN} * NOT Training Text Encoder, quality reduced *{Style.RESET_ALL}")
@ -366,7 +388,7 @@ def main(args):
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(
itertools.chain(params_to_train),
lr=lr,
lr=curr_lr,
betas=betas,
eps=epsilon,
weight_decay=weight_decay,
@ -375,23 +397,25 @@ def main(args):
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
optimizer = torch.optim.AdamW(
itertools.chain(params_to_train),
lr=lr,
lr=curr_lr,
betas=betas,
eps=epsilon,
weight_decay=weight_decay,
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon)
#log_optimizer(optimizer, betas, epsilon)
train_batch = EveryDreamBatch(
data_root=args.data_root,
flip_p=0.0,
flip_p=args.flip_p,
debug_level=1,
batch_size=args.batch_size,
conditional_dropout=args.cond_dropout,
resolution=args.resolution,
tokenizer=tokenizer,
seed = seed,
log_folder=log_folder,
)
torch.cuda.benchmark = False
@ -399,7 +423,7 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
if args.lr_decay_steps is None or args.lr_decay_steps < 1:
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.3)
args.lr_decay_steps = int(epoch_len * args.max_epochs * 1.5)
lr_warmup_steps = int(args.lr_decay_steps / 50) if args.lr_warmup_steps is None else args.lr_warmup_steps
@ -415,8 +439,6 @@ def main(args):
for line in f:
sample_prompts.append(line.strip())
log_folder = os.path.join(args.logdir, f"{args.project_name}{log_time}")
logging.info(f"Logging to {log_folder}")
if False: #args.wandb is not None and args.wandb: # not yet supported
log_writer = wandb.init(project="EveryDream2FineTunes",
@ -478,15 +500,15 @@ def main(args):
logging.info(f" saving ckpts every {args.ckpt_every_n_minutes} minutes")
logging.info(f" saving ckpts every {args.save_every_n_epochs } epochs")
scaler = torch.cuda.amp.GradScaler(
#enabled=False,
enabled=True if args.amp else False,
init_scale=2**1,
growth_factor=1.000001,
backoff_factor=0.9999999,
growth_interval=50,
)
logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
# scaler = torch.cuda.amp.GradScaler(
# #enabled=False,
# enabled=True if args.amp else False,
# init_scale=2**1,
# growth_factor=1.000001,
# backoff_factor=0.9999999,
# growth_interval=50,
# )
#logging.info(f" Grad scaler enabled: {scaler.is_enabled()}")
def collate_fn(batch):
"""
@ -514,11 +536,7 @@ def main(args):
collate_fn=collate_fn
)
total_batch_size = args.batch_size * args.grad_accum
unet.train()
text_encoder.requires_grad_(True)
text_encoder.train()
logging.info(f" unet device: {unet.device}, precision: {unet.dtype}, training: {unet.training}")
@ -548,6 +566,7 @@ def main(args):
append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer)
torch.cuda.empty_cache()
loss = torch.tensor(0.0, device=torch.device("cuda"), dtype=torch.float32)
try:
for epoch in range(args.max_epochs):
@ -555,27 +574,24 @@ def main(args):
logging.info(f" Saving model")
save_path = os.path.join(f"{log_folder}/ckpts/{args.project_name}-ep{epoch:02}-gs{global_step:05}")
__save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, args.save_ckpt_dir)
if epoch == args.text_encoder_epochs:
logging.info(f" Freezing text_encoder at epoch: {epoch}")
text_encoder.requires_grad_(False)
text_encoder.eval()
torch.cuda.empty_cache()
epoch_start_time = time.time()
steps_pbar.reset()
images_per_sec_epoch = []
for step, batch in enumerate(train_dataloader):
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
with torch.no_grad():
with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
#with autocast():
pixel_values = batch["image"].to(memory_format=torch.contiguous_format).to(unet.device)
with autocast(enabled=args.amp):
latents = vae.encode(pixel_values, return_dict=False)
latent = latents[0]
latents = latent.sample()
latents = latents * 0.18215
latents = latents * 0.18215
noise = torch.randn_like(latents)
bsz = latents.shape[0]
@ -583,17 +599,12 @@ def main(args):
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = batch["tokens"].to(text_encoder.device)
cuda_caption = batch["tokens"].to(text_encoder.device)
with autocast(enabled=args.amp):
#encoder_hidden_states = text_encoder(cuda_caption)
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
#with autocast(enabled=args.amp):
encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True)
#print("encoder_hidden_states.keys()", encoder_hidden_states.keys())
#print("encoder_hidden_states.hidden_states.shape", encoder_hidden_states.hidden_states[0].shape)
#print("encoder_hidden_states.last_hidden_state.shape", encoder_hidden_states.last_hidden_state[0].shape)
if args.clip_skip > 0: # TODO
if args.clip_skip > 0:
encoder_hidden_states = encoder_hidden_states.hidden_states[-args.clip_skip]
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
@ -606,35 +617,24 @@ def main(args):
target = scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
#del noise, latents
del noise, latents
with autocast(): # xformers requires autocast
with autocast(enabled=args.amp):
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
del timesteps, encoder_hidden_states, noisy_latents
#with autocast(enabled=args.amp):
loss = torch_functional.mse_loss(model_pred.float(), target.float(), reduction="mean")
#del timesteps, encoder_hidden_states, noisy_latents
if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=args.clip_grad_norm)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=args.clip_grad_norm)
if args.amp:
with autocast():
loss.backward()
#scaler.unscale_(optimizer)
#if args.clip_grad_norm is not None:
torch.nn.utils.clip_grad_norm_(parameters=unet.parameters(), max_norm=1)
torch.nn.utils.clip_grad_norm_(parameters=text_encoder.parameters(), max_norm=1)
#scaler.step(optimizer)
#scaler.update()
loss.backward()
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
optimizer.step()
optimizer.zero_grad()
else:
loss.backward()
if ((global_step + 1) % args.grad_accum == 0) or (step == epoch_len - 1):
optimizer.step()
optimizer.zero_grad()
optimizer.zero_grad(set_to_none=True)
lr_scheduler.step()
@ -645,10 +645,10 @@ def main(args):
images_per_sec_epoch.append(images_per_sec)
if (global_step + 1) % args.log_step == 0:
lr = lr_scheduler.get_last_lr()[0]
logs = {"loss/step": loss.detach().item(), "lr": lr, "img/s": images_per_sec}
curr_lr = lr_scheduler.get_last_lr()[0]
logs = {"loss/step": loss.detach().item(), "lr": curr_lr, "img/s": images_per_sec}
log_writer.add_scalar(tag="loss/step", scalar_value=loss, global_step=global_step)
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=lr, global_step=global_step)
log_writer.add_scalar(tag="hyperparamater/lr", scalar_value=curr_lr, global_step=global_step)
sum_img = sum(images_per_sec_epoch)
avg = sum_img / len(images_per_sec_epoch)
images_per_sec_epoch = []
@ -717,24 +717,23 @@ if __name__ == "__main__":
argparser = argparse.ArgumentParser(description="EveryDream Training options")
argparser.add_argument("--resume_ckpt", type=str, required=True, default="sd_v1-5_vae.ckpt")
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant scheduler")
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default")
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--ckpt_every_n_minutes", type=int, default=None, help="Save checkpoint every n minutes, def: 20")
argparser.add_argument("--save_every_n_epochs", type=int, default=None, help="Save checkpoint every n epochs, def: 0 (disabled)")
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="File with prompts to generate test samples from (def: sample_prompts.txt)")
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False) NOT RECOMMENDED")
argparser.add_argument("--batch_size", type=int, default=2, help="Batch size (def: 2)")
argparser.add_argument("--clip_grad_norm", type=float, default=None, help="Clip gradient norm (def: disabled) (ex: 1.5), useful if loss=nan?")
argparser.add_argument("--grad_accum", type=int, default=1, help="NONFUNCTIONING. Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--clip_skip", type=int, default=2, help="Train using penultimate layer (def: 2)", choices=[0, 1, 2])
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--clip_skip", type=int, default=0, help="Train using penultimate layer (def: 0) (2 is 'penultimate')", choices=[0, 1, 2, 3])
argparser.add_argument("--data_root", type=str, default="input", help="folder where your training images are")
argparser.add_argument("--mixed_precision", default="no", help="NONFUNCTIONING. precision, (default: NO for fp32)", choices=["NO", "fp16", "bf16"])
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
@ -742,7 +741,9 @@ if __name__ == "__main__":
argparser.add_argument("--cond_dropout", type=float, default=0.04, help="Conditional drop out as decimal 0.0-1.0, see docs for more info (def: 0.04)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
argparser.add_argument("--text_encoder_epochs", type=int, default=0, help="disable text encoder training after N steps (def: disabled)")
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5")
args = argparser.parse_args()
main(args)

36
utils/patch_unet.py Normal file
View File

@ -0,0 +1,36 @@
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import json
import logging
def patch_unet(ckpt_path):
"""
Patch the UNet to use updated attention heads for xformers support in FP32
"""
unet_cfg_path = os.path.join(ckpt_path, "unet", "config.json")
with open(unet_cfg_path, "r") as f:
unet_cfg = json.load(f)
if unet_cfg["attention_head_dim"] == [5, 10, 20, 20]:
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
return
unet_cfg["attention_head_dim"] = [5, 10, 20, 20]
logging.info(f" unet attention_head_dim: {unet_cfg['attention_head_dim']}")
with open(unet_cfg_path, "w") as f:
json.dump(unet_cfg, f, indent=2)