Put a file named loss_scale.txt containing a single float in a training folder to apply a loss scale to that folder's images (e.g. -1 for negative examples).

This commit is contained in:
Damian Stewart 2023-10-31 10:06:21 +01:00
parent 2f2dd4c1f2
commit da731268b2
4 changed files with 24 additions and 7 deletions

View File

@ -54,6 +54,7 @@ class ImageConfig:
cond_dropout: float = None cond_dropout: float = None
flip_p: float = None flip_p: float = None
shuffle_tags: bool = False shuffle_tags: bool = False
loss_scale: float = None
def merge(self, other): def merge(self, other):
if other is None: if other is None:
@ -68,7 +69,8 @@ class ImageConfig:
cond_dropout=overlay(other.cond_dropout, self.cond_dropout), cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p), flip_p=overlay(other.flip_p, self.flip_p),
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags), shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
batch_id=overlay(other.batch_id, self.batch_id) batch_id=overlay(other.batch_id, self.batch_id),
loss_scale=overlay(other.loss_scale, self.loss_scale)
) )
@classmethod @classmethod
@ -83,7 +85,8 @@ class ImageConfig:
cond_dropout=data.get("cond_dropout"), cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"), flip_p=data.get("flip_p"),
shuffle_tags=data.get("shuffle_tags"), shuffle_tags=data.get("shuffle_tags"),
batch_id=data.get("batch_id") batch_id=data.get("batch_id"),
loss_scale=data.get("loss_scale")
) )
# Alternatively parse from dedicated `caption` attribute # Alternatively parse from dedicated `caption` attribute
@ -170,6 +173,8 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yml'])) cfgs.append(ImageConfig.from_file(fileset['local.yml']))
if 'batch_id.txt' in fileset: if 'batch_id.txt' in fileset:
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt']))) cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
if 'loss_scale.txt' in fileset:
cfgs.append(ImageConfig(loss_scale=read_float(fileset['loss_scale.txt'])))
result = ImageConfig.fold(cfgs) result = ImageConfig.fold(cfgs)
if 'shuffle_tags.txt' in fileset: if 'shuffle_tags.txt' in fileset:
@ -264,7 +269,8 @@ class Dataset:
multiplier=config.multiply or 1.0, multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout, cond_dropout=config.cond_dropout,
shuffle_tags=config.shuffle_tags, shuffle_tags=config.shuffle_tags,
batch_id=config.batch_id batch_id=config.batch_id,
loss_scale=config.loss_scale
) )
items.append(item) items.append(item)
except Exception as e: except Exception as e:

View File

@ -118,6 +118,7 @@ class EveryDreamBatch(Dataset):
example["tokens"] = torch.tensor(example["tokens"]) example["tokens"] = torch.tensor(example["tokens"])
example["runt_size"] = train_item["runt_size"] example["runt_size"] = train_item["runt_size"]
example["loss_scale"] = train_item["loss_scale"]
return example return example
@ -134,6 +135,7 @@ class EveryDreamBatch(Dataset):
example["cond_dropout"] = image_train_tmp.cond_dropout example["cond_dropout"] = image_train_tmp.cond_dropout
example["runt_size"] = image_train_tmp.runt_size example["runt_size"] = image_train_tmp.runt_size
example["shuffle_tags"] = image_train_tmp.shuffle_tags example["shuffle_tags"] = image_train_tmp.shuffle_tags
example["loss_scale"] = image_train_tmp.loss_scale
return example return example
@ -214,11 +216,14 @@ def collate_fn(batch):
images = torch.stack(images) images = torch.stack(images)
images = images.to(memory_format=torch.contiguous_format).float() images = images.to(memory_format=torch.contiguous_format).float()
loss_scale = torch.tensor([example.get("loss_scale", 1) for example in batch])
ret = { ret = {
"tokens": torch.stack(tuple(tokens)), "tokens": torch.stack(tuple(tokens)),
"image": images, "image": images,
"captions": captions, "captions": captions,
"runt_size": runt_size, "runt_size": runt_size,
"loss_scale": loss_scale
} }
del batch del batch
return ret return ret

View File

@ -138,10 +138,11 @@ class ImageTrainItem:
aspects: list[float], aspects: list[float],
pathname: str, pathname: str,
flip_p=0.0, flip_p=0.0,
multiplier: float=1.0, multiplier: float=1.0,
cond_dropout=None, cond_dropout=None,
shuffle_tags=False, shuffle_tags=False,
batch_id: str=None batch_id: str=None,
loss_scale: float=None
): ):
self.caption = caption self.caption = caption
self.aspects = aspects self.aspects = aspects
@ -153,6 +154,7 @@ class ImageTrainItem:
self.cond_dropout = cond_dropout self.cond_dropout = cond_dropout
self.shuffle_tags = shuffle_tags self.shuffle_tags = shuffle_tags
self.batch_id = batch_id or DEFAULT_BATCH_ID self.batch_id = batch_id or DEFAULT_BATCH_ID
self.loss_scale = loss_scale or 1
self.target_wh = None self.target_wh = None
self.image_size = None self.image_size = None

View File

@ -1138,8 +1138,12 @@ def main(args):
del target, model_pred del target, model_pred
if batch["runt_size"] > 0: if batch["runt_size"] > 0:
loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5 runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5
loss = loss * loss_scale loss = loss * runt_loss_scale
if "loss_scale" in batch.keys():
# Apply the per-item loss scale (from loss_scale.txt) to the loss
loss = loss * batch["loss_scale"]
ed_optimizer.step(loss, step, global_step) ed_optimizer.step(loss, step, global_step)