Merge pull request #92 from victorchall/lion
fix mem leak on huge data, rework optimizer to separate json, add lion optimizer
This commit is contained in:
commit e7fc71ffa1
@@ -35,6 +35,8 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea

[Advanced Tweaking](doc/ATWEAKING.md) - More stuff to tweak once you are comfortable

[Advanced Optimizer Tweaking](/doc/OPTIMIZER.md) - Even more stuff to tweak if you are *very adventurous*

[Chaining training sessions](doc/CHAINING.md) - Modify training parameters by chaining training sessions together end to end

[Shuffling Tags](doc/SHUFFLING_TAGS.md)

@@ -49,4 +51,6 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea

[Free tier Google Colab notebook](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb)

[RunPod / Vast](/doc/CLOUD_SETUP.md)

[Docker image link](https://github.com/victorchall/EveryDream2trainer/pkgs/container/everydream2trainer)
@@ -0,0 +1,44 @@
{
    "amp": true,
    "batch_size": 10,
    "ckpt_every_n_minutes": null,
    "clip_grad_norm": null,
    "clip_skip": 0,
    "cond_dropout": 1.00,
    "data_root": "Q:/mldata/coco2017/train2017",
    "disable_textenc_training": false,
    "disable_xformers": false,
    "flip_p": 0.0,
    "gpuid": 0,
    "gradient_checkpointing": false,
    "grad_accum": 1,
    "logdir": "logs",
    "log_step": 25,
    "lowvram": false,
    "lr": 1.0e-06,
    "lr_decay_steps": 0,
    "lr_scheduler": "constant",
    "lr_warmup_steps": null,
    "max_epochs": 30,
    "notebook": false,
    "optimizer_config": "optimizer.json",
    "project_name": "coco_test",
    "resolution": 512,
    "resume_ckpt": "SD15",
    "run_name": null,
    "sample_prompts": "sample_prompts.txt",
    "sample_steps": 300,
    "save_ckpt_dir": null,
    "save_ckpts_from_n_epochs": 0,
    "save_every_n_epochs": 20,
    "save_optimizer": false,
    "scale_lr": false,
    "seed": 555,
    "shuffle_tags": false,
    "validation_config": "cfgs/validation_coco.json",
    "wandb": true,
    "write_schedule": false,
    "rated_dataset": false,
    "rated_dataset_target_dropout_percent": 50,
    "zero_frequency_noise_ratio": 0.04
}
@@ -0,0 +1,20 @@
{
    "documentation": {
        "validate_training": "If true, validate the training using a separate set of image/caption pairs, and log the results as `loss/val`. The curve will trend downwards as the model trains, then flatten and start to trend upwards as effective training finishes and the model begins to overfit the training data. Very useful for preventing overfitting, for checking if your learning rate is too low or too high, and for deciding when to stop training.",
        "val_split_mode": "Either 'automatic' or 'manual', ignored if validate_training is false. 'automatic' val_split_mode picks a random subset of the training set (the number of items is controlled by val_split_proportion) and removes them from training to use as a validation set. 'manual' val_split_mode lets you provide your own folder of validation items (images+captions), specified using 'val_data_root'.",
        "val_split_proportion": "For 'automatic' val_split_mode, how much of the train dataset should be removed to use for validation. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
        "val_data_root": "For 'manual' val_split_mode, the path to a folder containing validation items.",
        "stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
        "stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
        "every_n_epochs": "How often to run validation (1=every epoch).",
        "seed": "The seed to use when running validation and stabilization passes."
    },
    "validate_training": true,
    "val_split_mode": "manual",
    "val_data_root": "Q:/mldata/coco2017/val2017",
    "val_split_proportion": 0.15,
    "stabilize_training_loss": true,
    "stabilize_split_proportion": 0.15,
    "every_n_epochs": 1,
    "seed": 555
}
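For context on the 'automatic' mode described in the documentation block above: it removes a seeded random subset of the training items before training starts. A minimal sketch of that idea (hypothetical helper, not the trainer's actual implementation):

```python
import random

def automatic_val_split(items, val_split_proportion=0.15, seed=555):
    """Carve a seeded validation subset out of the training items (illustrative only)."""
    rng = random.Random(seed)          # fixed seed keeps the split identical between runs
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_val = int(len(shuffled) * val_split_proportion)
    val_items = shuffled[:n_val]       # held out, only used to compute loss/val
    train_items = shuffled[n_val:]     # the items the trainer actually sees
    return train_items, val_items
```

In 'manual' mode the validation items would instead come straight from the folder given by val_data_root.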
@@ -129,9 +129,11 @@ class EveryDreamBatch(Dataset):

        image_train_tmp = image_train_item.hydrate(crop=False, save=save, crop_jitter=self.crop_jitter)

        example["image"] = image_train_tmp.image
        example["image"] = image_train_tmp.image.copy()  # hack for now to avoid memory leak
        image_train_tmp.image = None  # hack for now to avoid memory leak
        example["caption"] = image_train_tmp.caption
        example["runt_size"] = image_train_tmp.runt_size

        return example

    def __update_image_train_items(self, dropout_fraction: float):
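The two "hack for now" lines above are the memory-leak fix: the batch example gets its own copy of the decoded image, and the dataset's cached reference is dropped, so hydrated images are not kept alive on every item for the rest of the run. A stripped-down sketch of the pattern (hypothetical class names, not the repo's code):

```python
from PIL import Image

class CachedItem:
    """Stands in for an item that lazily loads ('hydrates') a large decoded image."""
    def __init__(self, path: str):
        self.path = path
        self.image = None

    def hydrate(self):
        self.image = Image.open(self.path).convert("RGB")
        return self

def build_example(item: CachedItem) -> dict:
    item.hydrate()
    example = {"image": item.image.copy()}  # hand the caller an independent copy
    item.image = None                       # drop the cached reference so the big
                                            # decoded image can be garbage collected
    return example
```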
@@ -0,0 +1,39 @@
# Advanced optimizer tweaking

You can set advanced optimizer settings using this arg:

    --optimizer_config optimizer.json

or in train.json:

    "optimizer_config": "optimizer.json"

A default `optimizer.json` is supplied which you can modify.

This gives you expanded control over optimizer tweaking. This doc is incomplete, but there is information on the web about beta and weight decay settings that you can search for.

If you do not set optimizer_config, the defaults are `adamw8bit` with standard betas of `(0.9,0.999)`, weight decay `0.01`, and epsilon `1e-8`. The hyperparameters are originally from XavierXiao's Dreambooth code and based off the CompVis Stable Diffusion code.

## Optimizers

In `optimizer.json`, the `optimizer` value is the type of optimizer to use. Below are the supported optimizers.

* adamw

Standard full-precision AdamW optimizer exposed by PyTorch. Not recommended; it is slower and uses more memory than adamw8bit. Widely documented on the web.

* adamw8bit

Tim Dettmers' / bitsandbytes AdamW 8-bit optimizer. This is the default and recommended setting. Widely documented on the web.

* lion

Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click the links to read more. It is not yet known what hyperparameters will work well, but the paper shows potentially quicker learning. *Highly experimental, but tested and works.*

## Optimizer parameters

LR can be set in `optimizer.json` and excluded from the main CLI arg or train.json, but if you use the main CLI arg or set lr in the main train.json, that value will override the optimizer.json setting. This was done to make sure existing behavior will not break. To set LR in `optimizer.json`, make sure to delete `"lr": 1.3e-6` from your main train.json and exclude the CLI arg.

Betas, weight decay, and epsilon are documented in the [AdamW paper](https://arxiv.org/abs/1711.05101) and there is a wealth of information on the web, but consider tweaking them experimental. I cannot provide advice on what might be useful to tweak here.

Note that `lion` does not use epsilon.
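To make the mapping from `optimizer.json` to actual optimizer construction concrete, here is a rough sketch in the spirit of the train.py changes in this PR (simplified, not a drop-in replacement); note how lion takes betas and weight decay but no epsilon:

```python
import json
import torch

def build_optimizer(params, config_path="optimizer.json", fallback_lr=1e-6):
    with open(config_path, "r") as f:
        cfg = json.load(f)

    lr = cfg.get("lr") or fallback_lr          # null in the JSON falls back to the main config LR
    betas = tuple(cfg.get("betas", (0.9, 0.999)))
    weight_decay = cfg.get("weight_decay", 0.01)
    name = cfg.get("optimizer", "adamw8bit")

    if name == "lion":
        from lion_pytorch import Lion          # only needed when lion is selected
        return Lion(params, lr=lr, betas=betas, weight_decay=weight_decay)  # no eps argument

    if name in ("adam8bit", "adamw8bit"):
        import bitsandbytes as bnb             # only needed for the 8-bit variant
        opt_class = bnb.optim.AdamW8bit
    else:
        opt_class = torch.optim.AdamW

    return opt_class(params, lr=lr, betas=betas, eps=cfg.get("epsilon", 1e-8),
                     weight_decay=weight_decay)
```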
@@ -0,0 +1,15 @@
{
    "doc": {
        "optimizer": "adamw, adamw8bit, lion",
        "optimizer_desc": "'adamw' in standard 32bit, 'adamw8bit' is bitsandbytes, 'lion' is lucidrains",
        "lr": "learning rate, if null will use CLI or main JSON config value",
        "betas": "exponential decay rates for the moment estimates",
        "epsilon": "value added to denominator for numerical stability, unused for lion",
        "weight_decay": "weight decay (L2 penalty)"
    },
    "optimizer": "adamw8bit",
    "lr": null,
    "betas": [0.9, 0.999],
    "epsilon": 1e-8,
    "weight_decay": 0.01
}
@@ -21,9 +21,11 @@
    "lr_warmup_steps": null,
    "max_epochs": 30,
    "notebook": false,
    "optimizer_config": "optimizer.json",
    "project_name": "project_abc",
    "resolution": 512,
    "resume_ckpt": "sd_v1-5_vae",
    "run_name": null,
    "sample_prompts": "sample_prompts.txt",
    "sample_steps": 300,
    "save_ckpt_dir": null,

@@ -33,7 +35,6 @@
    "scale_lr": false,
    "seed": 555,
    "shuffle_tags": false,
    "useadam8bit": true,
    "validation_config": null,
    "wandb": false,
    "write_schedule": false,
train.py (90 changes)
@@ -125,12 +125,12 @@ def setup_local_logger(args):

    return datetimestamp

def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon):
def log_optimizer(optimizer: torch.optim.Optimizer, betas, epsilon, weight_decay):
    """
    logs the optimizer settings
    """
    logging.info(f"{Fore.CYAN} * Optimizer: {optimizer.__class__.__name__} *{Style.RESET_ALL}")
    logging.info(f"    betas: {betas}, epsilon: {epsilon} *{Style.RESET_ALL}")
    logging.info(f"{Fore.CYAN}    betas: {betas}, epsilon: {epsilon}, weight_decay: {weight_decay} *{Style.RESET_ALL}")

def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
    """
@@ -173,7 +173,6 @@ def append_epoch_log(global_step: int, epoch_pbar, gpu, log_writer, **logs):
    if logs is not None:
        epoch_pbar.set_postfix(**logs, vram=f"{epoch_mem_color}{gpu_used_mem}/{gpu_total_mem} MB{Style.RESET_ALL} gs:{global_step}")


def set_args_12gb(args):
    logging.info(" Setting args to 12GB mode")
    if not args.gradient_checkpointing:
@@ -182,15 +181,10 @@ def set_args_12gb(args):
    if args.batch_size != 1:
        logging.info(" - Overriding batch size to 1")
        args.batch_size = 1
    # if args.grad_accum != 1:
    #     logging.info(" Overriding grad accum to 1")
    args.grad_accum = 1
    if args.resolution > 512:
        logging.info(" - Overriding resolution to 512")
        args.resolution = 512
    if not args.useadam8bit:
        logging.info(" - Overriding adam8bit to True")
        args.useadam8bit = True

def find_last_checkpoint(logdir):
    """
@@ -236,6 +230,9 @@ def setup_args(args):

    args.clip_skip = max(min(4, args.clip_skip), 0)

    if args.useadam8bit:
        logging.warning(f"{Fore.LIGHTYELLOW_EX} Useadam8bit arg is deprecated, use optimizer.json instead, which defaults to useadam8bit anyway{Style.RESET_ALL}")

    if args.ckpt_every_n_minutes is None and args.save_every_n_epochs is None:
        logging.info(f"{Fore.LIGHTCYAN_EX} No checkpoint saving specified, defaulting to every 20 minutes.{Style.RESET_ALL}")
        args.ckpt_every_n_minutes = 20
@@ -453,7 +450,7 @@ def main(args):
    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()


    if not args.disable_xformers:
        if (args.amp and is_sd1attn) or (not is_sd1attn):
            try:
@@ -467,10 +464,6 @@
            logging.info("xformers disabled, using attention slicing instead")
            unet.set_attention_slice("auto")

    default_lr = 2e-6
    curr_lr = args.lr if args.lr is not None else default_lr


    vae = vae.to(device, dtype=torch.float16 if args.amp else torch.float32)
    unet = unet.to(device, dtype=torch.float32)
    if args.disable_textenc_training and args.amp:
@@ -488,36 +481,63 @@
        logging.info(f"{Fore.CYAN} * Training Text and Unet *{Style.RESET_ALL}")
        params_to_train = itertools.chain(unet.parameters(), text_encoder.parameters())

    if args.wandb is not None and args.wandb:
        wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args, name=args.run_name)

    log_writer = SummaryWriter(log_dir=log_folder,
                               flush_secs=5,
                               comment="EveryDream2FineTunes",
                               )
                               flush_secs=5,
                               comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
                               )

    betas = (0.9, 0.999)
    betas = [0.9, 0.999]
    epsilon = 1e-8
    if args.amp:
        epsilon = 1e-8

    weight_decay = 0.01
    if args.useadam8bit:
        import bitsandbytes as bnb
        opt_class = bnb.optim.AdamW8bit
        logging.info(f"{Fore.CYAN} * Using AdamW 8-bit Optimizer *{Style.RESET_ALL}")
    else:
        opt_class = torch.optim.AdamW
        logging.info(f"{Fore.CYAN} * Using AdamW standard Optimizer *{Style.RESET_ALL}")
    opt_class = torch.optim.AdamW
    optimizer = None

    optimizer = opt_class(
    default_lr = 1e-6
    curr_lr = args.lr

    # open optimizer.json and override optimizer args
    optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
    if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
        with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
            optimizer_config = json.load(f)
            betas = optimizer_config["betas"]
            epsilon = optimizer_config["epsilon"]
            weight_decay = optimizer_config["weight_decay"]
            optimizer_name = optimizer_config["optimizer"]
            curr_lr = optimizer_config.get("lr", curr_lr)
            logging.info(f" * Loaded optimizer args from {optimizer_config_path} *")

    if curr_lr is None:
        curr_lr = default_lr

    if optimizer_name:
        if optimizer_name == "lion":
            from lion_pytorch import Lion
            opt_class = Lion
            optimizer = opt_class(
                itertools.chain(params_to_train),
                lr=curr_lr,
                betas=(betas[0], betas[1]),
                weight_decay=weight_decay,
            )
        elif optimizer_name in ["adam8bit","adamw8bit"]:
            import bitsandbytes as bnb
            opt_class = bnb.optim.AdamW8bit

    if not optimizer:
        optimizer = opt_class(
            itertools.chain(params_to_train),
            lr=curr_lr,
            betas=betas,
            betas=(betas[0], betas[1]),
            eps=epsilon,
            weight_decay=weight_decay,
            amsgrad=False,
        )

    log_optimizer(optimizer, betas, epsilon)

    log_optimizer(optimizer, betas, epsilon, weight_decay)

    image_train_items = resolve_image_train_items(args, log_folder)
@@ -564,10 +584,6 @@
        num_training_steps=args.lr_decay_steps,
    )

    if args.wandb is not None and args.wandb:
        wandb.init(project=args.project_name, sync_tensorboard=True, dir=args.logdir, config=args)


    def log_args(log_writer, args):
        arglog = "args:\n"
        for arg, value in sorted(vars(args).items()):
@@ -922,9 +938,11 @@ if __name__ == "__main__":
    argparser.add_argument("--lr_warmup_steps", type=int, default=None, help="Steps to reach max LR during warmup (def: 0.02 of lr_decay_steps), non-functional for constant")
    argparser.add_argument("--max_epochs", type=int, default=300, help="Maximum number of epochs to train for")
    argparser.add_argument("--notebook", action="store_true", default=False, help="disable keypresses and uses tqdm.notebook for jupyter notebook (def: False)")
    argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
    argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
    argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
    argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")
    argparser.add_argument("--run_name", type=str, required=False, default=None, help="Run name for wandb (child of project name), and comment for tensorboard, (def: None)")
    argparser.add_argument("--sample_prompts", type=str, default="sample_prompts.txt", help="Text file with prompts to generate test samples from, or JSON file with sample generator settings (default: sample_prompts.txt)")
    argparser.add_argument("--sample_steps", type=int, default=250, help="Number of steps between samples (def: 250)")
    argparser.add_argument("--save_ckpt_dir", type=str, default=None, help="folder to save checkpoints to (def: root training folder)")
@@ -935,7 +953,7 @@ if __name__ == "__main__":
    argparser.add_argument("--scale_lr", action="store_true", default=False, help="automatically scale up learning rate based on batch size and grad accumulation (def: False)")
    argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
    argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
    argparser.add_argument("--useadam8bit", action="store_true", default=False, help="Use AdamW 8-Bit optimizer, recommended!")
    argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
    argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
    argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
    argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
utils/gpu.py (22 changes)
@@ -13,13 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from pynvml.smi import nvidia_smi
from pynvml.smi import nvidia_smi as smi
import pynvml
import torch

class GPU:
    def __init__(self, device: torch.device):
        self.nvsmi = nvidia_smi.getInstance()
        self.nvsmi = smi.getInstance()
        self.device = device

    def __querythis(self, query):
        return gpu_query['gpu'][self.device.index]

    def get_gpu_memory(self):
        """
@@ -29,4 +33,16 @@ class GPU:
        #print(gpu_query)
        gpu_used_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['used'])
        gpu_total_mem = int(gpu_query['gpu'][self.device.index]['fb_memory_usage']['total'])
        return gpu_used_mem, gpu_total_mem

    def supports_bfloat16(self):
        pynvml.nvmlInit()
        handle = pynvml.nvmlDeviceGetHandleByIndex(self.device.index)
        compute_compatibility = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
        return compute_compatibility[0] >= 8

    def driver_version(self):
        gpu_query = self.nvsmi.DeviceQuery('driver_version')
        driver_version = gpu_query['gpu'][self.device.index]['driver_version']
        return driver_version
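A quick usage sketch for the reworked class and the new helpers (assumes the repo is on the Python path and a CUDA device is available):

```python
import torch
from utils.gpu import GPU

gpu = GPU(torch.device("cuda:0"))

used_mb, total_mb = gpu.get_gpu_memory()
print(f"VRAM: {used_mb}/{total_mb} MB, driver {gpu.driver_version()}")

# supports_bfloat16() checks for CUDA compute capability 8.x or newer (Ampere and later)
train_dtype = torch.bfloat16 if gpu.supports_bfloat16() else torch.float16
```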