diff --git a/data/dataset.py b/data/dataset.py index a478a46..42689e7 100644 --- a/data/dataset.py +++ b/data/dataset.py @@ -54,6 +54,7 @@ class ImageConfig: cond_dropout: float = None flip_p: float = None shuffle_tags: bool = False + loss_scale: float = None def merge(self, other): if other is None: @@ -68,7 +69,8 @@ class ImageConfig: cond_dropout=overlay(other.cond_dropout, self.cond_dropout), flip_p=overlay(other.flip_p, self.flip_p), shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags), - batch_id=overlay(other.batch_id, self.batch_id) + batch_id=overlay(other.batch_id, self.batch_id), + loss_scale=overlay(other.loss_scale, self.loss_scale) ) @classmethod @@ -83,7 +85,8 @@ class ImageConfig: cond_dropout=data.get("cond_dropout"), flip_p=data.get("flip_p"), shuffle_tags=data.get("shuffle_tags"), - batch_id=data.get("batch_id") + batch_id=data.get("batch_id"), + loss_scale=data.get("loss_scale") ) # Alternatively parse from dedicated `caption` attribute @@ -170,6 +173,8 @@ class Dataset: cfgs.append(ImageConfig.from_file(fileset['local.yml'])) if 'batch_id.txt' in fileset: cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt']))) + if 'loss_scale.txt' in fileset: + cfgs.append(ImageConfig(loss_scale=read_float(fileset['loss_scale.txt']))) result = ImageConfig.fold(cfgs) if 'shuffle_tags.txt' in fileset: @@ -264,7 +269,8 @@ class Dataset: multiplier=config.multiply or 1.0, cond_dropout=config.cond_dropout, shuffle_tags=config.shuffle_tags, - batch_id=config.batch_id + batch_id=config.batch_id, + loss_scale=config.loss_scale ) items.append(item) except Exception as e: diff --git a/data/every_dream.py b/data/every_dream.py index d99ebb4..2ffc4c4 100644 --- a/data/every_dream.py +++ b/data/every_dream.py @@ -118,6 +118,7 @@ class EveryDreamBatch(Dataset): example["tokens"] = torch.tensor(example["tokens"]) example["runt_size"] = train_item["runt_size"] + example["loss_scale"] = train_item["loss_scale"] return example @@ -134,6 +135,7 @@ class 
EveryDreamBatch(Dataset): example["cond_dropout"] = image_train_tmp.cond_dropout example["runt_size"] = image_train_tmp.runt_size example["shuffle_tags"] = image_train_tmp.shuffle_tags + example["loss_scale"] = image_train_tmp.loss_scale return example @@ -214,11 +216,14 @@ def collate_fn(batch): images = torch.stack(images) images = images.to(memory_format=torch.contiguous_format).float() + loss_scale = torch.tensor([example.get("loss_scale", 1) for example in batch]) + ret = { "tokens": torch.stack(tuple(tokens)), "image": images, "captions": captions, "runt_size": runt_size, + "loss_scale": loss_scale } del batch return ret diff --git a/data/image_train_item.py b/data/image_train_item.py index a8979c3..4d6b377 100644 --- a/data/image_train_item.py +++ b/data/image_train_item.py @@ -138,10 +138,11 @@ class ImageTrainItem: aspects: list[float], pathname: str, flip_p=0.0, - multiplier: float=1.0, + multiplier: float=1.0, cond_dropout=None, shuffle_tags=False, - batch_id: str=None + batch_id: str=None, + loss_scale: float=None ): self.caption = caption self.aspects = aspects @@ -153,6 +154,7 @@ class ImageTrainItem: self.cond_dropout = cond_dropout self.shuffle_tags = shuffle_tags self.batch_id = batch_id or DEFAULT_BATCH_ID + self.loss_scale = 1 if loss_scale is None else loss_scale self.target_wh = None self.image_size = None diff --git a/train.py b/train.py index d6a44a0..b95cd1b 100644 --- a/train.py +++ b/train.py @@ -1138,8 +1138,12 @@ def main(args): del target, model_pred if batch["runt_size"] > 0: - loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 - loss = loss * loss_scale + runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 + loss = loss * runt_loss_scale + + if "loss_scale" in batch.keys(): + # Apply the per-sample loss_scale (a [batch]-shaped tensor, not a mask); NOTE(review): confirm it broadcasts against `loss` given the loss reduction used above + loss = loss * batch["loss_scale"]  ed_optimizer.step(loss, step, global_step)