fix issue with plugin interaction with validation
This commit is contained in:
parent
9141710d6e
commit
e89ec053c7
|
@ -19,6 +19,7 @@ from data.data_loader import DataLoaderMultiAspect
|
|||
from data import resolver
|
||||
from data import aspects
|
||||
from data.image_train_item import ImageTrainItem
|
||||
from plugins.plugins import PluginRunner
|
||||
from utils.isolate_rng import isolate_rng
|
||||
|
||||
from colorama import Fore, Style
|
||||
|
@ -276,6 +277,8 @@ class EveryDreamValidator:
|
|||
batch_size=batch_size,
|
||||
seed=seed,
|
||||
)
|
||||
empty_plugin_runner = PluginRunner()
|
||||
|
||||
ed_batch = EveryDreamBatch(
|
||||
data_loader=data_loader,
|
||||
debug_level=1,
|
||||
|
@ -283,6 +286,7 @@ class EveryDreamValidator:
|
|||
tokenizer=tokenizer,
|
||||
seed=seed,
|
||||
name=name,
|
||||
crop_jitter=0
|
||||
crop_jitter=0,
|
||||
empty_plugin_runner=empty_plugin_runner,
|
||||
)
|
||||
return ed_batch
|
||||
|
|
|
@ -51,7 +51,7 @@ class Timer:
|
|||
|
||||
|
||||
class PluginRunner:
|
||||
def __init__(self, plugins: list, epoch_warn_seconds=5, step_warn_seconds=0.5, training_warn_seconds=20):
|
||||
def __init__(self, plugins:list=[], epoch_warn_seconds=5, step_warn_seconds=0.5, training_warn_seconds=20):
|
||||
"""
|
||||
plugins: list of plugins to run
|
||||
epoch_warn_seconds: warn if any epoch start/end call takes longer than this
|
||||
|
|
Loading…
Reference in New Issue