From da731268b2615336307e6145b7cd1005851f14e8 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Tue, 31 Oct 2023 10:06:21 +0100 Subject: [PATCH 1/4] put a file loss_scale.txt containing a float in a training folder to apply loss scale (eg -1 for negative examples) --- data/dataset.py | 12 +++++++++--- data/every_dream.py | 5 +++++ data/image_train_item.py | 6 ++++-- train.py | 8 ++++++-- 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/data/dataset.py b/data/dataset.py index a478a46..42689e7 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -54,6 +54,7 @@ class ImageConfig: cond_dropout: float = None flip_p: float = None shuffle_tags: bool = False + loss_scale: float = None def merge(self, other): if other is None: @@ -68,7 +69,8 @@ class ImageConfig: cond_dropout=overlay(other.cond_dropout, self.cond_dropout), flip_p=overlay(other.flip_p, self.flip_p), shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags), - batch_id=overlay(other.batch_id, self.batch_id) + batch_id=overlay(other.batch_id, self.batch_id), + loss_scale=overlay(other.loss_scale, self.loss_scale) ) @classmethod @@ -83,7 +85,8 @@ class ImageConfig: cond_dropout=data.get("cond_dropout"), flip_p=data.get("flip_p"), shuffle_tags=data.get("shuffle_tags"), - batch_id=data.get("batch_id") + batch_id=data.get("batch_id"), + loss_scale=data.get("loss_scale") ) # Alternatively parse from dedicated `caption` attribute @@ -170,6 +173,8 @@ class Dataset: cfgs.append(ImageConfig.from_file(fileset['local.yml'])) if 'batch_id.txt' in fileset: cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt']))) + if 'loss_scale.txt' in fileset: + cfgs.append(ImageConfig(loss_scale=read_float(fileset['loss_scale.txt']))) result = ImageConfig.fold(cfgs) if 'shuffle_tags.txt' in fileset: @@ -264,7 +269,8 @@ class Dataset: multiplier=config.multiply or 1.0, cond_dropout=config.cond_dropout, shuffle_tags=config.shuffle_tags, - batch_id=config.batch_id + batch_id=config.batch_id, + 
loss_scale=config.loss_scale ) items.append(item) except Exception as e: diff --git a/data/every_dream.py b/data/every_dream.py index d99ebb4..2ffc4c4 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -118,6 +118,7 @@ class EveryDreamBatch(Dataset): example["tokens"] = torch.tensor(example["tokens"]) example["runt_size"] = train_item["runt_size"] + example["loss_scale"] = train_item["loss_scale"] return example @@ -134,6 +135,7 @@ class EveryDreamBatch(Dataset): example["cond_dropout"] = image_train_tmp.cond_dropout example["runt_size"] = image_train_tmp.runt_size example["shuffle_tags"] = image_train_tmp.shuffle_tags + example["loss_scale"] = image_train_tmp.loss_scale return example @@ -214,11 +216,14 @@ def collate_fn(batch): images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() + loss_scale = torch.tensor([example.get("loss_scale", 1) for example in batch]) + ret = { "tokens": torch.stack(tuple(tokens)), "image": images, "captions": captions, "runt_size": runt_size, + "loss_scale": loss_scale } del batch return ret diff --git a/data/image_train_item.py b/data/image_train_item.py index a8979c3..4d6b377 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -138,10 +138,11 @@ class ImageTrainItem: aspects: list[float], pathname: str, flip_p=0.0, - multiplier: float=1.0, + multiplier: float=1.0, cond_dropout=None, shuffle_tags=False, - batch_id: str=None + batch_id: str=None, + loss_scale: float=None ): self.caption = caption self.aspects = aspects @@ -153,6 +154,7 @@ class ImageTrainItem: self.cond_dropout = cond_dropout self.shuffle_tags = shuffle_tags self.batch_id = batch_id or DEFAULT_BATCH_ID + self.loss_scale = loss_scale or 1 self.target_wh = None self.image_size = None diff --git a/train.py b/train.py index d6a44a0..b95cd1b 100644 --- a/train.py +++ b/train.py @@ -1138,8 +1138,12 @@ def main(args): del target, model_pred if batch["runt_size"] > 0: - loss_scale = (batch["runt_size"] / 
args.batch_size)**1.5 # further discount runts by **1.5 - loss = loss * loss_scale + runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 + loss = loss * runt_loss_scale + + if "loss_scale" in batch.keys(): + # Apply the mask to the loss + loss = loss * batch["loss_scale"] ed_optimizer.step(loss, step, global_step) From a7343ad1908c2b176c22a667ffabd9adf59de429 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 1 Nov 2023 08:11:42 +0100 Subject: [PATCH 2/4] fix scale batch calculation --- train.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/train.py b/train.py index b95cd1b..a8c867b 100644 --- a/train.py +++ b/train.py @@ -901,7 +901,7 @@ def main(args): assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" # actual prediction function - shared between train and validate - def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False): + def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None): with torch.no_grad(): with autocast(enabled=args.amp): pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device) @@ -947,10 +947,10 @@ def main(args): model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if return_loss: - if args.min_snr_gamma is None: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + if loss_scale is None: + loss_scale = torch.ones(model_pred.shape[0], dtype=torch.float) - else: + if args.min_snr_gamma is not None: snr = compute_snr(timesteps, noise_scheduler) mse_loss_weights = ( @@ -960,9 +960,11 @@ def main(args): / snr ) mse_loss_weights[snr == 0] = 1.0 - loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights - loss = loss.mean() + loss_scale = loss_scale * 
mse_loss_weights + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale + loss = loss.mean() return model_pred, target, loss @@ -1133,7 +1135,11 @@ def main(args): batch=batch, ed_state=make_current_ed_state()) - model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True) + model_pred, target, loss = get_model_prediction_and_target(batch["image"], + batch["tokens"], + args.zero_frequency_noise_ratio, + return_loss=True, + loss_scale=batch["loss_scale"]) del target, model_pred @@ -1141,10 +1147,6 @@ def main(args): runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 loss = loss * runt_loss_scale - if "loss_scale" in batch.keys(): - # Apply the mask to the loss - loss = loss * batch["loss_scale"] - ed_optimizer.step(loss, step, global_step) if args.ema_decay_rate != None: From c485d4ea6095333aff4160ec44ff4305dd39e7d4 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 1 Nov 2023 09:29:41 +0100 Subject: [PATCH 3/4] fix device mismatch with loss_scale --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index a8c867b..979d8b2 100644 --- a/train.py +++ b/train.py @@ -963,7 +963,7 @@ def main(args): loss_scale = loss_scale * mse_loss_weights loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") - loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device) loss = loss.mean() return model_pred, target, loss From 86aaf1c4d71e98a3616804a540a9d40532c5f2fb Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Wed, 1 Nov 2023 19:00:08 +0100 Subject: [PATCH 4/4] fix bug when loss_scale.txt contains 0 --- data/image_train_item.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/data/image_train_item.py b/data/image_train_item.py index 4d6b377..1f9c7ea 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -154,7 +154,7 @@ class ImageTrainItem: self.cond_dropout = cond_dropout self.shuffle_tags = shuffle_tags self.batch_id = batch_id or DEFAULT_BATCH_ID - self.loss_scale = loss_scale or 1 + self.loss_scale = 1 if loss_scale is None else loss_scale self.target_wh = None self.image_size = None