statistics for pbar
This commit is contained in:
parent
40b56c9289
commit
348f89c8d4
|
@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
size = len(ds.indexes)
|
size = len(ds.indexes)
|
||||||
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
loss_dict = defaultdict(lambda : deque(maxlen = 1024))
|
||||||
losses = torch.zeros((size,))
|
losses = torch.zeros((size,))
|
||||||
|
previous_mean_losses = [0]
|
||||||
previous_mean_loss = 0
|
previous_mean_loss = 0
|
||||||
print("Mean loss of {} elements".format(size))
|
print("Mean loss of {} elements".format(size))
|
||||||
|
|
||||||
|
@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
for i, entries in pbar:
|
for i, entries in pbar:
|
||||||
hypernetwork.step = i + ititial_step
|
hypernetwork.step = i + ititial_step
|
||||||
if len(loss_dict) > 0:
|
if len(loss_dict) > 0:
|
||||||
previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
|
previous_mean_losses = [i[-1] for i in loss_dict.values()]
|
||||||
|
previous_mean_loss = mean(previous_mean_losses)
|
||||||
|
|
||||||
scheduler.apply(optimizer, hypernetwork.step)
|
scheduler.apply(optimizer, hypernetwork.step)
|
||||||
if scheduler.finished:
|
if scheduler.finished:
|
||||||
|
@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
|
||||||
|
|
||||||
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
|
||||||
raise RuntimeError("Loss diverged.")
|
raise RuntimeError("Loss diverged.")
|
||||||
pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
|
|
||||||
|
if len(previous_mean_losses) > 1:
|
||||||
|
std = stdev(previous_mean_losses)
|
||||||
|
else:
|
||||||
|
std = 0
|
||||||
|
dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
|
||||||
|
pbar.set_description(dataset_loss_info)
|
||||||
|
|
||||||
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
|
||||||
# Before saving, change name to match current checkpoint.
|
# Before saving, change name to match current checkpoint.
|
||||||
|
|
Loading…
Reference in New Issue