Merge pull request #92 from victorchall/lion

fix mem leak on huge data, rework optimizer to separate json, add lio…
Victor Hall 2023-02-25 16:20:10 -05:00 committed by GitHub
commit e7fc71ffa1
9 changed files with 201 additions and 42 deletions

View File

@ -35,6 +35,8 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea
[Advanced Tweaking](doc/ATWEAKING.md) - More stuff to tweak once you are comfortable
[Advanced Optimizer Tweaking](/doc/OPTIMIZER.md) - Even more stuff to tweak if you are *very adventurous*
[Chaining training sessions](doc/CHAINING.md) - Modify training parameters by chaining training sessions together end to end
[Shuffling Tags](doc/SHUFFLING_TAGS.md)
@ -50,3 +52,5 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea
[Free tier Google Colab notebook](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb)
[RunPod / Vast](/doc/CLOUD_SETUP.md)
[Docker image link](https://github.com/victorchall/EveryDream2trainer/pkgs/container/everydream2trainer)

cfgs/train_coco.json (new file, 44 lines)

@ -0,0 +1,44 @@
{
"amp": true,
"batch_size": 10,
"ckpt_every_n_minutes": null,
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 1.00,
"data_root": "Q:/mldata/coco2017/train2017",
"disable_textenc_training": false,
"disable_xformers": false,
"flip_p": 0.0,
"gpuid": 0,
"gradient_checkpointing": false,
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"lowvram": false,
"lr": 1.0e-06,
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 30,
"notebook": false,
"optimizer_config": "optimizer.json",
"project_name": "coco_test",
"resolution": 512,
"resume_ckpt": "SD15",
"run_name": null,
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
"save_ckpts_from_n_epochs": 0,
"save_every_n_epochs": 20,
"save_optimizer": false,
"scale_lr": false,
"seed": 555,
"shuffle_tags": false,
"validation_config": "cfgs/validation_coco.json",
"wandb": true,
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.04
}

cfgs/validation_coco.json (new file, 20 lines)

@ -0,0 +1,20 @@
{
"documentation": {
"validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
"val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
"val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset that should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
"every_n_epochs": "How often to run validation (1=every epoch).",
"seed": "The seed to use when running validation and stabilization passes."
},
"validate_training":"true",
"val_split_mode": "manual",
"val_data_root": "Q:/mldata/coco2017/val2017",
"val_split_proportion": 0.15,
"stabilize_training_loss": true,
"stabilize_split_proportion": 0.15,
"every_n_epochs": 1,
"seed": 555
}
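For reference, this file uses 'manual' mode with a separate validation folder. A minimal sketch of what the 'automatic' mode described in the documentation block amounts to, using a hypothetical `split_automatic` helper and an `items` list (an illustration of the idea, not the trainer's actual code):

```python
import random

def split_automatic(items, val_split_proportion=0.15, seed=555):
    """Carve a validation subset out of the training items (illustrative only)."""
    rng = random.Random(seed)                    # fixed seed keeps the split reproducible
    val_count = int(len(items) * val_split_proportion)
    val_items = rng.sample(items, val_count)     # random subset reported as loss/val
    chosen = set(id(item) for item in val_items)
    train_items = [item for item in items if id(item) not in chosen]  # removed from training
    return train_items, val_items
```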

View File

@ -129,9 +129,11 @@ class EveryDreamBatch(Dataset):
image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)
example["image"] = image_train_tmp.image
example["image"] = image_train_tmp.image.copy() # hack for now to avoid memory leak
image_train_tmp.image = None # hack for now to avoid memory leak
example["caption"] = image_train_tmp.caption
example["runt_size"] = image_train_tmp.runt_size
return example
def __update_image_train_items(self, dropout_fraction: float):

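The copy-then-drop lines above ("hack for now to avoid memory leak") work because the dataset keeps every item object alive for the entire run, so any decoded image cached on an item stays in memory too. A minimal sketch of the pattern using a hypothetical `CachedItem` class (not the repo's actual `ImageTrainItem`):

```python
from PIL import Image

class CachedItem:
    """Stand-in for an item that caches its decoded image after hydrate()."""
    def __init__(self, pathname: str):
        self.pathname = pathname
        self.image = None

    def hydrate(self):
        self.image = Image.open(self.pathname).convert("RGB")  # large decoded pixel buffer
        return self

def build_example(item: CachedItem) -> dict:
    item.hydrate()
    example = {"image": item.image.copy()}  # hand the caller its own copy of the pixels
    item.image = None                       # drop the cached reference so the decoded
                                            # buffer can be garbage collected
    return example
```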
doc/OPTIMIZER.md (new file, 39 lines)

@ -0,0 +1,39 @@
# Advanced optimizer tweaking
You can set advanced optimizer settings using this arg:
`--optimizer_config optimizer.json`
or in train.json:
`"optimizer_config": "optimizer.json"`
A default `optimizer.json` is supplied which you can modify.
This file exposes additional settings to tweak. This doc is incomplete, but there is information on the web about betas and weight decay settings that you can search for.
If you do not set optimizer_config, the defaults are `adamw8bit` with standard betas of `(0.9, 0.999)`, weight decay `0.01`, and epsilon `1e-8`. These hyperparameters originally come from XavierXiao's Dreambooth code and are based on the CompVis Stable Diffusion code.
## Optimizers
In `optimizer.json` the `optimizer` value is the type of optimizer to use. Below are the supported optimizers.
* adamw
Standard full precision AdamW optimizer exposed by PyTorch. Not recommended. Slower and uses more memory than adamw8bit. Widely documented on the web.
* adamw8bit
Tim Dettmers / bitsandbytes AdamW 8bit optimizer. This is the default and recommended setting. Widely documented on the web.
* lion
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click the links to read more. It is not yet known which hyperparameters work well, but the paper shows potentially quicker learning. *Highly experimental, but tested and works.*
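To make the mapping concrete, here is a rough sketch of the constructor call each name corresponds to (illustrative only; the `lr` value is just an example and `params` stands in for the real unet/text encoder parameters):

```python
import torch
import bitsandbytes as bnb      # only needed for adamw8bit
from lion_pytorch import Lion   # only needed for lion

params = list(torch.nn.Linear(4, 4).parameters())  # stand-in for the real model parameters

# adamw: standard full-precision AdamW from PyTorch
optimizer = torch.optim.AdamW(params, lr=1e-6, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# adamw8bit: bitsandbytes 8-bit AdamW (the default)
# optimizer = bnb.optim.AdamW8bit(params, lr=1e-6, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01)

# lion: lucidrains' Lion; note there is no epsilon argument
# optimizer = Lion(params, lr=1e-6, betas=(0.9, 0.999), weight_decay=0.01)
```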
## Optimizer parameters
LR can be set in `optimizer.json` instead of the main CLI arg or train.json, but if you use the main CLI arg or set the LR in the main train.json, that value will override the optimizer.json setting. This was done to make sure existing behavior does not break. To set the LR in `optimizer.json`, make sure to delete `"lr": 1.3e-6` from your main train.json and omit the CLI arg.
Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider tweaking them experimental. I cannot provide advice on what might be useful to tweak here.
Note `lion` does not use epsilon.
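A minimal sketch of the documented precedence (illustrative, not the trainer's exact code; the helper name and arguments are assumptions):

```python
import json
import os

def resolve_lr(main_lr, optimizer_config_path="optimizer.json", default_lr=1e-6):
    """CLI/train.json lr wins if set, else optimizer.json lr, else a default."""
    if main_lr is not None:                  # --lr or "lr" in train.json takes priority
        return main_lr
    if os.path.exists(optimizer_config_path):
        with open(optimizer_config_path) as f:
            opt_lr = json.load(f).get("lr")  # consulted only when the main lr is unset
        if opt_lr is not None:
            return opt_lr
    return default_lr
```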

optimizer.json (new file, 15 lines)

@ -0,0 +1,15 @@
{
"doc": {
"optimizer": "adamw, adamw8bit, lion",
"optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
"lr": "learning rate, if null wil use CLI or main JSON config value",
"betas": "exponential decay rates for the moment estimates",
"epsilon": "value added to denominator for numerical stability, unused for lion",
"weight_decay": "weight decay (L2 penalty)"
},
"optimizer": "adamw8bit",
"lr:": null,
"betas": [0.9, 0.999],
"epsilon": 1e-8,
"weight_decay": 0.01
}

View File

@ -21,9 +21,11 @@
"lr_warmup_steps": null,
"max_epochs": 30,
"notebook": false,
"optimizer_config": "optimizer.json",
"project_name": "project_abc",
"resolution": 512,
"resume_ckpt": "sd_v1-5_vae",
"run_name": null,
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
"save_ckpt_dir": null,
@ -33,7 +35,6 @@
"scale_lr": false,
"seed": 555,
"shuffle_tags": false,
"useadam8bit": true,
"validation_config": null,
"wandb": false,
"write_schedule": false,

View File

@ -125,12 +125,12 @@ def setup_local_logger(args):
return datetimestamp
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay):
"""
logs the optimizer settings
"""
logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
logging.info(f" betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
logging.info(f"{Fore.CYAN} betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")
def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
"""
@ -173,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
if logs is not None:
epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")
def set_args_12gb(args):
logging.info(" Setting args to 12GB mode")
if not args.gradient_checkpointing:
@ -182,15 +181,10 @@ def set_args_12gb(args):
if args.batch_size != 1:
logging.info(" - Overiding batch size to 1")
args.batch_size = 1
# if args.grad_accum != 1:
# logging.info(" Overiding grad accum to 1")
args.grad_accum = 1
if args.resolution > 512:
logging.info(" - Overiding resolution to 512")
args.resolution = 512
if not args.useadam8bit:
logging.info(" - Overiding adam8bit to True")
args.useadam8bit = True
def find_last_checkpoint(logdir):
"""
@ -236,6 +230,9 @@ def setup_args(args):
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.useadam8bit:
logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")
if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
args.ckpt_every_n_minutes = 20
@ -467,10 +464,6 @@ def main(args):
logging.info("xformers disabled, using attention slicing instead")
unet.set_attention_slice("auto")
default_lr = 2e-6
curr_lr = args.lr if args.lr is not None else default_lr
vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
unet = unet.to(device, dtype=torch.float32)
if args.disable_textenc_training and args.amp:
@ -488,36 +481,63 @@ def main(args):
logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args, name=args.run_name)
log_writer = SummaryWriter(log_dir=log_folder,
flush_secs=5,
comment="EveryDream2FineTunes",
comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
)
betas = (0.9, 0.999)
betas = [0.9, 0.999]
epsilon = 1e-8
if args.amp:
epsilon = 1e-8
weight_decay = 0.01
if args.useadam8bit:
import bitsandbytes as bnb
opt_class = bnb.optim.AdamW8bit
logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
else:
opt_class = torch.optim.AdamW
logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
optimizer = None
default_lr = 1e-6
curr_lr = args.lr
# open optimizer.json and override optimizer args
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
optimizer_config = json.load(f)
betas = optimizer_config["betas"]
epsilon = optimizer_config["epsilon"]
weight_decay = optimizer_config["weight_decay"]
optimizer_name = optimizer_config["optimizer"]
curr_lr = optimizer_config.get("lr", curr_lr)
logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")
if curr_lr is None:
curr_lr = default_lr
if optimizer_name:
if optimizer_name == "lion":
from lion_pytorch import Lion
opt_class = Lion
optimizer = opt_class(
itertools.chain(params_to_train),
lr=curr_lr,
betas=betas,
betas=(betas[0], betas[1]),
weight_decay=weight_decay,
)
elif optimizer_name in ["adam8bit","adamw8bit"]:
import bitsandbytes as bnb
opt_class = bnb.optim.AdamW8bit
if not optimizer:
optimizer = opt_class(
itertools.chain(params_to_train),
lr=curr_lr,
betas=(betas[0], betas[1]),
eps=epsilon,
weight_decay=weight_decay,
amsgrad=False,
)
log_optimizer(optimizer, betas, epsilon)
log_optimizer(optimizer, betas, epsilon, weight_decay)
image_train_items = resolve_image_train_items(args, log_folder)
@ -564,10 +584,6 @@ def main(args):
num_training_steps=args.lr_decay_steps,
)
if args.wandb is not None and args.wandb:
wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args)
def log_args(log_writer, args):
arglog = "args:\n"
for arg, value in sorted(vars(args).items()):
@ -922,9 +938,11 @@ if __name__ == "__main__":
argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
argparser.add_argument("--run_name", type=str, required=False, default=None, help="Run name for wandb (child of project name), and comment for tensorboard, (def: None)")
argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
@ -935,7 +953,7 @@ if __name__ == "__main__":
argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")

View File

@ -13,14 +13,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pynvml.smi import nvidia_smi
from pynvml.smi import nvidia_smi as smi
import pynvml
import torch
class GPU:
def __init__(self, device: torch.device):
self.nvsmi = nvidia_smi.getInstance()
self.nvsmi = smi.getInstance()
self.device = device
def __querythis(self, query):
gpu_query = self.nvsmi.DeviceQuery(query)
return gpu_query['gpu'][self.device.index]
def get_gpu_memory(self):
"""
returns a tuple of [gpu_used_mem, gpu_total_mem]
@ -30,3 +34,15 @@ class GPU:
gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
return gpu_used_mem, gpu_total_mem
def supports_bfloat16(self):
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
compute_compatibility = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
return compute_compatibility[0] >= 8
def driver_version(self):
gpu_query = self.nvsmi.DeviceQuery('driver_version')
driver_version = gpu_query['gpu'][self.device.index]['driver_version']
return driver_version
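A short usage sketch of the class above (the `utils.gpu` import path is an assumption about the repo layout, adjust as needed):

```python
import torch
from utils.gpu import GPU   # assumed module path for the GPU class shown above

device = torch.device("cuda:0")
gpu = GPU(device)

used_mb, total_mb = gpu.get_gpu_memory()               # framebuffer memory per nvidia-smi
print(f"VRAM: {used_mb}/{total_mb} MB")
print("bfloat16 supported:", gpu.supports_bfloat16())  # compute capability >= 8 (Ampere or newer)
print("driver:", gpu.driver_version())
```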