EveryDream2trainer/plugins/accumulnator.py

74 lines
3.2 KiB
Python
Raw Normal View History

import json
import logging
2023-10-22 13:41:20 -06:00
import math
import os
2023-10-22 13:41:20 -06:00
import torch
from plugins.plugins import BasePlugin
class Accumulnator(BasePlugin):
def __init__(self):
path = os.path.join(os.path.dirname(__file__), "accumulnator.json")
2023-10-22 15:00:54 -06:00
logging.info(f" * Accumulnator plugin instantiated, loading config from {path}")
with open(path, 'rt') as f:
config = json.load(f)
begin_epoch = config['begin_epoch']
begin_grad_accum = config['begin_grad_accum']
end_epoch = config['end_epoch']
end_grad_accum = config['end_grad_accum']
2023-10-22 13:41:20 -06:00
# spread the grad accums
2023-10-22 12:09:53 -06:00
curve = config['curve']
2023-10-22 14:02:38 -06:00
steps = end_epoch - begin_epoch
2023-10-22 13:41:20 -06:00
if curve == 'linear':
accums = torch.linspace(start=begin_grad_accum,
end=end_grad_accum,
steps=end_epoch-begin_epoch).tolist()
elif curve == 'log':
accums = torch.logspace(start=math.log(begin_grad_accum, 2),
end=math.log(end_grad_accum, 2),
base=2,
2023-10-22 14:02:38 -06:00
steps=steps).tolist()
2023-10-22 13:41:20 -06:00
else:
raise NotImplementedError(f"curve not {curve} not recognized")
2023-10-22 14:02:38 -06:00
#print(f"accums: {accums}")
accums_per_epoch = {}
for i in range(begin_epoch):
accums_per_epoch[i] = begin_grad_accum
2023-10-22 14:02:38 -06:00
for i in range(steps):
#print(f"took accum {accums[i]} for epoch {i+begin_epoch}")
2023-10-22 13:41:20 -06:00
accums_per_epoch[i+begin_epoch] = round(accums[i])
2023-10-22 15:00:54 -06:00
logging.info(f" * Accumulnator will set grad_accum as follows: {accums_per_epoch}")
self.per_epoch_grad_accum = accums_per_epoch
def on_epoch_end(self, **kwargs):
just_finished_epoch = kwargs['epoch']
epoch = just_finished_epoch + 1
grad_accum = self.per_epoch_grad_accum.get(epoch)
if grad_accum is None:
2023-10-22 15:00:54 -06:00
logging.warning(f" * Accumulnator has no grad_accum setting for epoch {epoch} - leaving as-is")
else:
2023-10-22 15:00:54 -06:00
logging.info(f" * Accumulnator setting grad_accum for epoch {epoch} to {grad_accum}")
arg_update_callback = kwargs['arg_update_callback']
arg_update_callback('grad_accum', grad_accum)
def _get_update_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
if self.every_n_epochs >= 1:
if ((epoch+1) % self.every_n_epochs) == 0:
# last step only
return [epoch_length_steps-1]
else:
return []
else:
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
# validation happens after training:
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
validate_every_n_steps = epoch_length_steps / num_divisions
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]