Merge pull request #235 from damian0815/feat_negative_loss
add loss_scale.txt
This commit is contained in:
commit
30b063dfec
|
@ -54,6 +54,7 @@ class ImageConfig:
|
||||||
cond_dropout: float = None
|
cond_dropout: float = None
|
||||||
flip_p: float = None
|
flip_p: float = None
|
||||||
shuffle_tags: bool = False
|
shuffle_tags: bool = False
|
||||||
|
loss_scale: float = None
|
||||||
|
|
||||||
def merge(self, other):
|
def merge(self, other):
|
||||||
if other is None:
|
if other is None:
|
||||||
|
@ -68,7 +69,8 @@ class ImageConfig:
|
||||||
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
|
||||||
flip_p=overlay(other.flip_p, self.flip_p),
|
flip_p=overlay(other.flip_p, self.flip_p),
|
||||||
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
|
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
|
||||||
batch_id=overlay(other.batch_id, self.batch_id)
|
batch_id=overlay(other.batch_id, self.batch_id),
|
||||||
|
loss_scale=overlay(other.loss_scale, self.loss_scale)
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -83,7 +85,8 @@ class ImageConfig:
|
||||||
cond_dropout=data.get("cond_dropout"),
|
cond_dropout=data.get("cond_dropout"),
|
||||||
flip_p=data.get("flip_p"),
|
flip_p=data.get("flip_p"),
|
||||||
shuffle_tags=data.get("shuffle_tags"),
|
shuffle_tags=data.get("shuffle_tags"),
|
||||||
batch_id=data.get("batch_id")
|
batch_id=data.get("batch_id"),
|
||||||
|
loss_scale=data.get("loss_scale")
|
||||||
)
|
)
|
||||||
|
|
||||||
# Alternatively parse from dedicated `caption` attribute
|
# Alternatively parse from dedicated `caption` attribute
|
||||||
|
@ -170,6 +173,8 @@ class Dataset:
|
||||||
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
|
||||||
if 'batch_id.txt' in fileset:
|
if 'batch_id.txt' in fileset:
|
||||||
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
|
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
|
||||||
|
if 'loss_scale.txt' in fileset:
|
||||||
|
cfgs.append(ImageConfig(loss_scale=read_float(fileset['loss_scale.txt'])))
|
||||||
|
|
||||||
result = ImageConfig.fold(cfgs)
|
result = ImageConfig.fold(cfgs)
|
||||||
if 'shuffle_tags.txt' in fileset:
|
if 'shuffle_tags.txt' in fileset:
|
||||||
|
@ -264,7 +269,8 @@ class Dataset:
|
||||||
multiplier=config.multiply or 1.0,
|
multiplier=config.multiply or 1.0,
|
||||||
cond_dropout=config.cond_dropout,
|
cond_dropout=config.cond_dropout,
|
||||||
shuffle_tags=config.shuffle_tags,
|
shuffle_tags=config.shuffle_tags,
|
||||||
batch_id=config.batch_id
|
batch_id=config.batch_id,
|
||||||
|
loss_scale=config.loss_scale
|
||||||
)
|
)
|
||||||
items.append(item)
|
items.append(item)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
@ -118,6 +118,7 @@ class EveryDreamBatch(Dataset):
|
||||||
example["tokens"] = torch.tensor(example["tokens"])
|
example["tokens"] = torch.tensor(example["tokens"])
|
||||||
|
|
||||||
example["runt_size"] = train_item["runt_size"]
|
example["runt_size"] = train_item["runt_size"]
|
||||||
|
example["loss_scale"] = train_item["loss_scale"]
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
@ -134,6 +135,7 @@ class EveryDreamBatch(Dataset):
|
||||||
example["cond_dropout"] = image_train_tmp.cond_dropout
|
example["cond_dropout"] = image_train_tmp.cond_dropout
|
||||||
example["runt_size"] = image_train_tmp.runt_size
|
example["runt_size"] = image_train_tmp.runt_size
|
||||||
example["shuffle_tags"] = image_train_tmp.shuffle_tags
|
example["shuffle_tags"] = image_train_tmp.shuffle_tags
|
||||||
|
example["loss_scale"] = image_train_tmp.loss_scale
|
||||||
|
|
||||||
return example
|
return example
|
||||||
|
|
||||||
|
@ -214,11 +216,14 @@ def collate_fn(batch):
|
||||||
images = torch.stack(images)
|
images = torch.stack(images)
|
||||||
images = images.to(memory_format=torch.contiguous_format).float()
|
images = images.to(memory_format=torch.contiguous_format).float()
|
||||||
|
|
||||||
|
loss_scale = torch.tensor([example.get("loss_scale", 1) for example in batch])
|
||||||
|
|
||||||
ret = {
|
ret = {
|
||||||
"tokens": torch.stack(tuple(tokens)),
|
"tokens": torch.stack(tuple(tokens)),
|
||||||
"image": images,
|
"image": images,
|
||||||
"captions": captions,
|
"captions": captions,
|
||||||
"runt_size": runt_size,
|
"runt_size": runt_size,
|
||||||
|
"loss_scale": loss_scale
|
||||||
}
|
}
|
||||||
del batch
|
del batch
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -141,7 +141,8 @@ class ImageTrainItem:
|
||||||
multiplier: float=1.0,
|
multiplier: float=1.0,
|
||||||
cond_dropout=None,
|
cond_dropout=None,
|
||||||
shuffle_tags=False,
|
shuffle_tags=False,
|
||||||
batch_id: str=None
|
batch_id: str=None,
|
||||||
|
loss_scale: float=None
|
||||||
):
|
):
|
||||||
self.caption = caption
|
self.caption = caption
|
||||||
self.aspects = aspects
|
self.aspects = aspects
|
||||||
|
@ -153,6 +154,7 @@ class ImageTrainItem:
|
||||||
self.cond_dropout = cond_dropout
|
self.cond_dropout = cond_dropout
|
||||||
self.shuffle_tags = shuffle_tags
|
self.shuffle_tags = shuffle_tags
|
||||||
self.batch_id = batch_id or DEFAULT_BATCH_ID
|
self.batch_id = batch_id or DEFAULT_BATCH_ID
|
||||||
|
self.loss_scale = 1 if loss_scale is None else loss_scale
|
||||||
self.target_wh = None
|
self.target_wh = None
|
||||||
|
|
||||||
self.image_size = None
|
self.image_size = None
|
||||||
|
|
22
train.py
22
train.py
|
@ -901,7 +901,7 @@ def main(args):
|
||||||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||||
|
|
||||||
# actual prediction function - shared between train and validate
|
# actual prediction function - shared between train and validate
|
||||||
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
|
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with autocast(enabled=args.amp):
|
with autocast(enabled=args.amp):
|
||||||
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
||||||
|
@ -947,10 +947,10 @@ def main(args):
|
||||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||||
|
|
||||||
if return_loss:
|
if return_loss:
|
||||||
if args.min_snr_gamma is None:
|
if loss_scale is None:
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
loss_scale = torch.ones(model_pred.shape[0], dtype=torch.float)
|
||||||
|
|
||||||
else:
|
if args.min_snr_gamma is not None:
|
||||||
snr = compute_snr(timesteps, noise_scheduler)
|
snr = compute_snr(timesteps, noise_scheduler)
|
||||||
|
|
||||||
mse_loss_weights = (
|
mse_loss_weights = (
|
||||||
|
@ -960,8 +960,10 @@ def main(args):
|
||||||
/ snr
|
/ snr
|
||||||
)
|
)
|
||||||
mse_loss_weights[snr == 0] = 1.0
|
mse_loss_weights[snr == 0] = 1.0
|
||||||
|
loss_scale = loss_scale * mse_loss_weights
|
||||||
|
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * loss_scale.to(unet.device)
|
||||||
loss = loss.mean()
|
loss = loss.mean()
|
||||||
|
|
||||||
return model_pred, target, loss
|
return model_pred, target, loss
|
||||||
|
@ -1133,13 +1135,17 @@ def main(args):
|
||||||
batch=batch,
|
batch=batch,
|
||||||
ed_state=make_current_ed_state())
|
ed_state=make_current_ed_state())
|
||||||
|
|
||||||
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
|
model_pred, target, loss = get_model_prediction_and_target(batch["image"],
|
||||||
|
batch["tokens"],
|
||||||
|
args.zero_frequency_noise_ratio,
|
||||||
|
return_loss=True,
|
||||||
|
loss_scale=batch["loss_scale"])
|
||||||
|
|
||||||
del target, model_pred
|
del target, model_pred
|
||||||
|
|
||||||
if batch["runt_size"] > 0:
|
if batch["runt_size"] > 0:
|
||||||
loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5
|
runt_loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5
|
||||||
loss = loss * loss_scale
|
loss = loss * runt_loss_scale
|
||||||
|
|
||||||
ed_optimizer.step(loss, step, global_step)
|
ed_optimizer.step(loss, step, global_step)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue