add pbar back to preloading, remove cruft from testing loss stuff
This commit is contained in:
parent
d1bc94fe3e
commit
ba687de8b4
|
@ -9,6 +9,8 @@ from data.image_train_item import ImageCaption, ImageTrainItem
|
||||||
from utils.fs_helpers import *
|
from utils.fs_helpers import *
|
||||||
from typing import TypeVar, Iterable
|
from typing import TypeVar, Iterable
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def overlay(overlay, base):
|
def overlay(overlay, base):
|
||||||
return overlay if overlay is not None else base
|
return overlay if overlay is not None else base
|
||||||
|
@ -216,7 +218,7 @@ class Dataset:
|
||||||
|
|
||||||
def image_train_items(self, aspects):
|
def image_train_items(self, aspects):
|
||||||
items = []
|
items = []
|
||||||
for image in self.image_configs:
|
for image in tqdm(self.image_configs, desc="preloading", dynamic_ncols=True):
|
||||||
config = self.image_configs[image]
|
config = self.image_configs[image]
|
||||||
if len(config.main_prompts) > 1:
|
if len(config.main_prompts) > 1:
|
||||||
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
logging.warning(f" *** Found multiple multiple main_prompts for image {image}, but only one will be applied: {config.main_prompts}")
|
||||||
|
|
6
train.py
6
train.py
|
@ -811,12 +811,6 @@ def main(args):
|
||||||
|
|
||||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
||||||
|
|
||||||
# with torch.no_grad():
|
|
||||||
# loss_l1 = F.l1_loss(model_pred.float(), target.float(), reduction="mean")
|
|
||||||
# log_writer.add_scalar(tag="loss/l1", scalar_value=loss_l1, global_step=global_step)
|
|
||||||
# loss_hinge = F.hinge_embedding_loss(model_pred.float(), target.float(), reduction="mean")
|
|
||||||
# log_writer.add_scalar(tag="loss/hinge", scalar_value=loss_hinge, global_step=global_step)
|
|
||||||
|
|
||||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||||
|
|
||||||
del target, model_pred
|
del target, model_pred
|
||||||
|
|
Loading…
Reference in New Issue