add missing json, fix error

This commit is contained in:
Damian Stewart 2023-10-22 20:09:53 +02:00
parent 9396d2156e
commit 6434844432
3 changed files with 15 additions and 1 deletions

10
plugins/accumulnator.json Normal file
View File

@ -0,0 +1,10 @@
{
"documentation": {
"curve": "one of 'linear' or 'log', defines how the grad_accum value changes from begin to end"
},
"begin_epoch": 3,
"end_epoch": 10,
"begin_grad_accum": 1,
"end_grad_accum": 100,
"curve": "linear"
}

View File

@ -15,6 +15,10 @@ class Accumulnator(BasePlugin):
begin_grad_accum = config['begin_grad_accum']
end_epoch = config['end_epoch']
end_grad_accum = config['end_grad_accum']
curve = config['curve']
if curve != 'linear':
raise NotImplementedError("Only 'linear' curve is implemented for now")
accums_per_epoch = {}
for i in range(begin_epoch):
accums_per_epoch[i] = begin_grad_accum

View File

@ -159,7 +159,7 @@ def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, s
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
if args.ema_decay_rate != None:
if ed_state.unet_ema is not None or ed_state.text_encoder_ema is not None:
pipeline_ema = StableDiffusionPipeline(
vae=ed_state.vae,
text_encoder=ed_state.text_encoder_ema,