add huber loss, timestep clamping, slightly safer txt reading

This commit is contained in:
Victor Hall 2024-04-26 23:54:31 -04:00
parent f369e534eb
commit 3a6fe3b4a1
6 changed files with 94 additions and 34 deletions

2
.gitignore vendored
View File

@ -16,3 +16,5 @@
.idea
/.cache
/models
/*.safetensors
/*.webp

View File

@ -41,10 +41,6 @@ Covers install, setup of base models, startning training, basic tweaking, and lo
Behind the scenes look at how the trainer handles multiaspect and crop jitter
### Companion tools repo
Make sure to check out the [tools repo](https://github.com/victorchall/EveryDream), it has a grab bag of scripts to help with your data curation prior to training. It has automatic bulk BLIP captioning for BLIP, script to web scrape based on Laion data files, script to rename generic pronouns to proper names or append artist tags to your captions, etc.
## Cloud/Docker
### [Free tier Google Colab notebook](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/Train_Colab.ipynb)
@ -81,7 +77,7 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea
[Validation](doc/VALIDATION.md) - Use a validation split on your data to see when you are overfitting and tune hyperparameters
[Captioning](doc/CAPTION.md) - (beta) tools to automate captioning
[Captioning](doc/CAPTION.md) - tools to generate synthetic captioning (recommend [Cog](doc/CAPTION_COG.md))
[Plugins](doc/PLUGINS.md) - (beta) write your own plugins to execute arbitrary code during training

View File

@ -36,21 +36,45 @@ This is useful if you want to dump the CKPT files directly to your webui/inferen
## Conditional dropout
Conditional dropout means the prompt or caption on the training image is dropped, and the caption is "blank". The theory is this can help with unconditional guidance, per the original paper and authors of Latent Diffusion and Stable Diffusion.
Conditional dropout means the prompt or caption on the training image is dropped, and the caption is "blank". This can help with unconditional guidance, per the original paper and authors of Latent Diffusion and Stable Diffusion. This means the CFG Scale used at inference time will respond more smoothly.
The value is defaulted at 0.04, which means 4% conditional dropout. You can set it to 0.0 to disable it, or increase it. Many users of EveryDream 1.0 have had great success tweaking this, especially for larger models. You may wish to try 0.10. This may also be useful to really "force" a style into the model with a high setting such as 0.15. However, setting it very high may lead to bleeding or overfitting to your training data, especially if your data is not very diverse, which may or may not be desirable for your project.
The value is defaulted at 0.04, which means 4% conditional dropout. You can set it to 0.0 to disable it, or increase it. For larger training (many tens of thousands) using 0.10 would be my recommendation.
This may also be useful to really "force" a style into the model with a high setting such as 0.15. However, setting it very high may lead to bleeding or overfitting to your training data, especially if your data is not very diverse, which may or may not be desirable for your project.
--cond_dropout 0.1 ^
## Timestep clamping
Stable Diffusion uses 1000 possible timesteps for denoising steps. If you wish to train only a portion of those timesteps instead of the entire schedule you can clamp the value.
Timesteps are always chosen randomly per training example, per step, within the possible or allowed timesteps.
For instance, if you only want to train from 500 to 999, use this:
--timestep_start 500
Or if you only want to try from 0 to 449, use this:
--timestep_end 450
Possible use cases are to "focus" training on aesthetics or composition. It's likely you may need to train all timesteps as a "clean up" if you train just specific timestep ranges first.
## Loss Type
You can change the type of loss from the standard [MSE ("L2") loss](https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html) to [Huber loss](https://pytorch.org/docs/stable/generated/torch.nn.HuberLoss.html), or a interpolated value across timesteps. Valid values are "mse", "huber", "mse_huber", and "huber_mse".
--loss_type huber
mse_huber will use MSE at timestep 0 and huber at timestep 999, and interpolate between the two across the intermediate timesteps. huber_mse is the reverse
## LR tweaking
Learning rate adjustment is a very important part of training.
You should use [Optimizer config](doc/OPTIMZER.md) to tweak instead of the primary arg here, but it is left for legacy support of the Jupyter Notebook to make it easier to use the Jupyter Notbook in a happy path or simplified scenario.
--lr 1.0e-6 ^
By default, the learning rate is constant for the entire training session. However, if you want it to change by itself during training, you can use cosine.
General suggestion is 1e-6 for training SD1.5 at 512 resolution. For SD2.1 at 768, try a much lower value, such as 2e-7. [Validation](VALIDATION.md) can be helpful to tune learning rate.
*If you set this in train.json or the main CLI arg it will override the value from your optimizer.json, so use with caution...* Again, best to use optimizer.json instead.
## Clip skip

View File

@ -14,6 +14,7 @@
"grad_accum": 1,
"logdir": "logs",
"log_step": 25,
"loss_type": "mse",
"max_epochs": 40,
"notebook": false,
"optimizer_config": "optimizer.json",
@ -29,17 +30,17 @@
"save_optimizer": false,
"scale_lr": false,
"seed": 555,
"timestep_start": 0,
"timestep_end": 1000,
"shuffle_tags": false,
"validation_config": "validation_default.json",
"wandb": false,
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02,
"pyramid_noise_discount": null,
"pyramid_noise_discount": 0.03,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"min_snr_gamma": 5.0,
"ema_decay_rate": null,
"ema_strength_target": null,
"ema_update_interval": null,

View File

@ -391,6 +391,11 @@ def setup_args(args):
args.aspects = aspects.get_aspect_buckets(args.resolution)
if args.timestep_start < 0:
raise ValueError("timestep_start must be >= 0")
if args.timestep_end > 1000:
raise ValueError("timestep_end must be <= 1000")
return args
@ -727,16 +732,22 @@ def main(args):
text_encoder_ema = None
try:
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
#vae = torch.compile(vae)
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.allow_tf32 = True
print()
#unet = torch.compile(unet, mode="max-autotune")
#text_encoder = torch.compile(text_encoder, mode="max-autotune")
#vae = torch.compile(vae, mode="max-autotune")
#logging.info("Successfully compiled models")
except Exception as ex:
logging.warning(f"Failed to compile model, continuing anyway, ex: {ex}")
pass
try:
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.allow_tf32 = True
except Exception as ex:
logging.warning(f"Failed to set float32 matmul precision, continuing anyway, ex: {ex}")
pass
optimizer_config = None
optimizer_config_path = args.optimizer_config if args.optimizer_config else "optimizer.json"
if os.path.exists(os.path.join(os.curdir, optimizer_config_path)):
@ -944,7 +955,7 @@ def main(args):
bsz = latents.shape[0]
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = torch.randint(args.timestep_start, args.timestep_end, (bsz,), device=latents.device)
timesteps = timesteps.long()
cuda_caption = tokens.to(text_encoder.device)
@ -987,9 +998,32 @@ def main(args):
mse_loss_weights[snr == 0] = 1.0
loss_scale = loss_scale * mse_loss_weights.to(loss_scale.device)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
loss = loss.mean()
loss_mse = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss_scale = loss_scale.view(-1, 1, 1, 1).expand_as(loss_mse)
if args.loss_type == "mse_huber":
early_timestep_bias = (timesteps / noise_scheduler.config.num_train_timesteps)
early_timestep_bias = torch.tensor(early_timestep_bias, dtype=torch.float).to(unet.device)
early_timestep_bias = early_timestep_bias.view(-1, 1, 1, 1).expand_as(loss_mse)
loss_huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
loss_mse = loss_mse * loss_scale.to(unet.device) * early_timestep_bias
loss_huber = loss_huber * loss_scale.to(unet.device) * (1.0 - early_timestep_bias)
loss = loss_mse.mean() + loss_huber.mean()
elif args.loss_type == "huber_mse":
early_timestep_bias = (timesteps / noise_scheduler.config.num_train_timesteps)
early_timestep_bias = torch.tensor(early_timestep_bias, dtype=torch.float).to(unet.device)
early_timestep_bias = early_timestep_bias.view(-1, 1, 1, 1).expand_as(loss_mse)
loss_huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
loss_mse = loss_mse * loss_scale.to(unet.device) * (1.0 - early_timestep_bias)
loss_huber = loss_huber * loss_scale.to(unet.device) * early_timestep_bias
loss = loss_huber.mean() + loss_mse.mean()
elif args.loss_type == "huber":
loss_huber = F.huber_loss(model_pred.float(), target.float(), reduction="none", delta=1.0)
loss_huber = loss_huber * loss_scale.to(unet.device)
loss = loss_huber.mean()
else:
loss_mse = loss_mse * loss_scale.to(unet.device)
loss = loss_mse.mean()
return model_pred, target, loss
@ -1334,6 +1368,7 @@ if __name__ == "__main__":
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
argparser.add_argument("--log_step", type=int, default=25, help="How often to log training stats, def: 25, recommend default!")
argparser.add_argument("--loss_type", type=str, default="mse_huber", help="type of loss, 'huber', 'mse', or 'mse_huber' for interpolated (def: mse_huber)", choices=["huber", "mse", "mse_huber"])
argparser.add_argument("--lr", type=float, default=None, help="Learning rate, if using scheduler is maximum LR at top of curve")
argparser.add_argument("--lr_decay_steps", type=int, default=0, help="Steps to reach minimum LR, default: automatically set")
argparser.add_argument("--lr_scheduler", type=str, default="constant", help="LR scheduler, (default: constant)", choices=["constant", "linear", "cosine", "polynomial"])
@ -1356,6 +1391,8 @@ if __name__ == "__main__":
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--timestep_start", type=int, default=0, help="Noising timestep minimum (def: 0)")
argparser.add_argument("--timestep_end", type=int, default=1000, help="Noising timestep (def: 1000)")
argparser.add_argument("--train_sampler", type=str, default="ddpm", help="noise sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, used to randomly select subset of tags when shuffling is enabled, def: 0 (shuffle all)")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")

View File

@ -16,16 +16,16 @@ def is_image(file):
def read_text(file):
try:
with open(file, encoding='utf-8', mode='r') as stream:
return stream.read().strip()
encodings = ['utf-8', 'iso-8859-1', 'windows-1252', 'latin-1']
for encoding in encodings:
try:
with open(file, encoding=encoding) as f:
return f.read()
except UnicodeDecodeError:
continue
raise UnicodeDecodeError(f'Could not decode file with any of the provided encodings: {encodings}')
except Exception as e:
logging.warning(f" *** Error reading text file as utf-8: {file}: {e}")
try:
with open(file, encoding='latin-1', mode='r') as stream:
return stream.read().strip()
except Exception as e:
logging.warning(f" *** Error reading text file as latin-1: {file}: {e}")
logging.warning(f" *** Error reading text file: {file}: {e}")
def read_float(file):
try: