added some resolutions, option for val-loss pos-neg, fix wandb
This commit is contained in:
parent
3744bc0dc9
commit
35d52b56e0
|
@ -13,6 +13,48 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
"""
|
||||
|
||||
"""
|
||||
Notes:
|
||||
this is generated from an excel sheet and actual ratios are hand picked to
|
||||
spread out the ratios evenly to avoid having super-finely defined buckets
|
||||
Too many buckets means more "runt" steps with repeated images to fill batch than necessary
|
||||
ex. we do not need both 1.0:1 and 1.125:1, they're almost identical ratios
|
||||
Try to keep to fewer than ~20 ratio buckets per resolution; that should be plenty to cover everything between 1:1 and 4:1
|
||||
More finely defined buckets will reduce cropping at the expense of more runt steps
|
||||
"""
|
||||
|
||||
# Aspect buckets for the 1536 base resolution: the square bucket first, then
# each landscape [width, height] immediately followed by its portrait mirror.
# Trailing comments give the pixel count and the aspect ratio of the pair.
ASPECTS_1536 = [
    [1536, 1536],                # 2359296 1:1
    [1728, 1344], [1344, 1728],  # 2322432 1.286:1
    [1792, 1280], [1280, 1792],  # 2293760 1.4:1
    [2048, 1152], [1152, 2048],  # 2359296 1.778:1
    [2304, 1024], [1024, 2304],  # 2359296 2.25:1
    [2432, 960], [960, 2432],    # 2334720 2.53:1
    [2624, 896], [896, 2624],    # 2351104 2.929:1
    [2816, 832], [832, 2816],    # 2342912 3.385:1
    [3072, 768], [768, 3072],    # 2359296 4:1
]
|
||||
|
||||
# Aspect buckets for the 1408 base resolution: the square bucket first, then
# each landscape [width, height] immediately followed by its portrait mirror.
# Trailing comments give the pixel count and the aspect ratio of the pair.
ASPECTS_1408 = [
    [1408, 1408],                # 1982464 1:1
    [1536, 1280], [1280, 1536],  # 1966080 1.2:1
    [1664, 1152], [1152, 1664],  # 1916928 1.444:1
    [1920, 1024], [1024, 1920],  # 1966080 1.875:1
    [2048, 960], [960, 2048],    # 1966080 2.133:1
    [2368, 832], [832, 2368],    # 1970176 2.846:1
    [2560, 768], [768, 2560],    # 1966080 3.333:1
    # The portrait entry must mirror the landscape one (was [704,3072], which
    # is neither 4:1 nor the same pixel count as the rest of this table).
    [2816, 704], [704, 2816],    # 1982464 4:1
]
|
||||
|
||||
# Aspect buckets for the 1280 base resolution: the square bucket first, then
# each landscape [width, height] immediately followed by its portrait mirror.
# Trailing comments give the pixel count and the aspect ratio of the pair.
ASPECTS_1280 = [
    [1280, 1280],                # 1638400 1:1
    # The portrait entry must mirror the landscape one (was [1408,1344], which
    # left this ratio without a portrait bucket and added an oversized one).
    [1408, 1152], [1152, 1408],  # 1622016 1.222:1
    [1600, 1024], [1024, 1600],  # 1638400 1.563:1
    [1792, 896], [896, 1792],    # 1605632 2:1
    [1920, 832], [832, 1920],    # 1597440 2.308:1
    [2112, 768], [768, 2112],    # 1622016 2.75:1
    [2304, 704], [704, 2304],    # 1622016 3.27:1
    [2560, 640], [640, 2560],    # 1638400 4:1
]
|
||||
|
||||
ASPECTS_1152 = [[1152,1152], # 1327104 1:1
|
||||
#[1216,1088],[1088,1216], # 1323008 1.118:1
|
||||
[1280,1024],[1024,1280], # 1310720 1.25:1
|
||||
|
@ -48,7 +90,7 @@ ASPECTS_1024 = [[1024,1024], # 1048576 1:1
|
|||
]
|
||||
|
||||
ASPECTS_960 = [[960,960], # 921600 1:1
|
||||
[1024,896],[896,1024], # 917504 1.143:1
|
||||
#[1024,896],[896,1024], # 917504 1.143:1
|
||||
[1088,832],[832,1088], # 905216 1.308:1
|
||||
[1152,768],[768,1152], # 884736 1.5:1
|
||||
[1280,704],[704,1280], # 901120 1.818:1
|
||||
|
@ -56,11 +98,11 @@ ASPECTS_960 = [[960,960], # 921600 1:1
|
|||
[1680,576],[576,1680], # 921600 2.778:1
|
||||
#[1728,512],[512,1728], # 884736 3.375:1
|
||||
[1792,512],[512,1792], # 917504 3.5:1
|
||||
[2048,448],[448,2048], # 917504 4.714:1
|
||||
[2048,448],[448,2048], # 917504 4.57:1
|
||||
]
|
||||
|
||||
ASPECTS_896 = [[896,896], # 802816 1:1
|
||||
[960,832],[832,960], # 798720 1.153:1
|
||||
#[960,832],[832,960], # 798720 1.153:1
|
||||
[1024,768],[768,1024], # 786432 1.333:1
|
||||
[1088,704],[704,1088], # 765952 1.545:1
|
||||
[1216,640],[640,1216], # 778240 1.9:1
|
||||
|
@ -155,7 +197,7 @@ ASPECTS_384 = [[384,384], # 147456 1:1
|
|||
ASPECTS_256 = [[256,256], # 65536 1:1
|
||||
[384,192],[192,384], # 73728 2:1
|
||||
[512,128],[128,512], # 65536 4:1
|
||||
]
|
||||
] # very few buckets available for 256 with 64 pixel increments
|
||||
|
||||
def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
|
||||
if resolution < 256:
|
||||
|
@ -173,6 +215,10 @@ def get_aspect_buckets(resolution, square_only=False, reduced_buckets=False):
|
|||
print(f" *** Unsupported resolution of {resolution}, check your resolution config")
|
||||
print(f" *** Value must be between 512 and 1024")
|
||||
raise e
|
||||
|
||||
def get_supported_resolutions():
    """
    Return the base resolution of every aspect table from __get_all_aspects().

    The resolution is read from the first number of the first bucket of each
    table (the visible ASPECTS_* tables all start with their square bucket).
    """
    return [buckets[0][0] for buckets in __get_all_aspects()]
|
||||
|
||||
def __get_all_aspects():
|
||||
return [ASPECTS_256,
|
||||
|
@ -188,5 +234,7 @@ def __get_all_aspects():
|
|||
ASPECTS_960,
|
||||
ASPECTS_1024,
|
||||
ASPECTS_1088,
|
||||
ASPECTS_1152
|
||||
ASPECTS_1152,
|
||||
ASPECTS_1280,
|
||||
ASPECTS_1536,
|
||||
]
|
|
@ -40,12 +40,12 @@ class EveryDreamValidator:
|
|||
val_config_path: Optional[str],
|
||||
default_batch_size: int,
|
||||
resolution: int,
|
||||
log_writer: SummaryWriter):
|
||||
log_writer: SummaryWriter,
|
||||
):
|
||||
self.val_dataloader = None
|
||||
self.train_overlapping_dataloader = None
|
||||
|
||||
self.log_writer = log_writer
|
||||
self.resolution = resolution
|
||||
self.log_writer = log_writer
|
||||
|
||||
self.config = {
|
||||
'batch_size': default_batch_size,
|
||||
|
@ -57,7 +57,9 @@ class EveryDreamValidator:
|
|||
'val_split_proportion': 0.15,
|
||||
|
||||
'stabilize_training_loss': False,
|
||||
'stabilize_split_proportion': 0.15
|
||||
'stabilize_split_proportion': 0.15,
|
||||
|
||||
'use_relative_loss': False,
|
||||
}
|
||||
if val_config_path is not None:
|
||||
with open(val_config_path, 'rt') as f:
|
||||
|
@ -67,7 +69,7 @@ class EveryDreamValidator:
|
|||
self.val_loss_offset = None
|
||||
|
||||
self.loss_val_history = []
|
||||
self.val_loss_window_size = 4 # todo: arg for this?
|
||||
self.val_loss_window_size = 5 # todo: arg for this?
|
||||
|
||||
@property
|
||||
def batch_size(self):
|
||||
|
@ -80,6 +82,10 @@ class EveryDreamValidator:
|
|||
@property
|
||||
def seed(self):
|
||||
return self.config['seed']
|
||||
|
||||
@property
def use_relative_loss(self):
    """Return the 'use_relative_loss' entry of the validator config."""
    settings = self.config
    return settings['use_relative_loss']
|
||||
|
||||
def prepare_validation_splits(self, train_items: list[ImageTrainItem], tokenizer: Any) -> list[ImageTrainItem]:
|
||||
"""
|
||||
|
@ -117,14 +123,19 @@ class EveryDreamValidator:
|
|||
if self.val_loss_offset is None:
|
||||
self.val_loss_offset = -mean_loss
|
||||
self.log_writer.add_scalar(tag=f"loss/val",
|
||||
scalar_value=self.val_loss_offset + mean_loss,
|
||||
scalar_value=mean_loss if not self.use_relative_loss else self.val_loss_offset + mean_loss,
|
||||
global_step=global_step)
|
||||
self.loss_val_history.append(mean_loss)
|
||||
if len(self.loss_val_history) > (self.val_loss_window_size * 2 + 1):
|
||||
dy = np.diff(self.loss_val_history[-self.val_loss_window_size:])
|
||||
if np.average(dy) > 0:
|
||||
logging.warning(f"Validation loss shows diverging")
|
||||
# todo: signal stop?
|
||||
|
||||
|
||||
self.track_loss_trend(mean_loss)
|
||||
|
||||
def track_loss_trend(self, mean_loss):
    """
    Record a validation loss value and warn if the recent trend is upward.

    The value is appended to self.loss_val_history. Once the history holds
    more than (2 * val_loss_window_size + 1) entries, the step-to-step
    differences over the last val_loss_window_size entries are averaged and a
    warning is logged when that average is positive.
    """
    history = self.loss_val_history
    history.append(mean_loss)

    window = self.val_loss_window_size
    if len(history) <= (window * 2) + 1:
        # not enough samples collected yet to judge a trend
        return

    recent_deltas = np.diff(history[-window:])
    if np.average(recent_deltas) > 0:
        logging.warning("Validation loss shows diverging. Check your val/loss graph.")
|
||||
|
||||
def _calculate_validation_loss(self, tag, dataloader, get_model_prediction_and_target: Callable[
|
||||
[Any, Any], tuple[torch.Tensor, torch.Tensor]]) -> float:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
aiohttp==3.8.4
|
||||
colorama==0.4.6
|
||||
diffusers[torch]>=0.13.0
|
||||
diffusers[torch]>=0.14.0
|
||||
ftfy==6.1.1
|
||||
ipyevents
|
||||
ipywidgets
|
||||
|
@ -14,7 +14,7 @@ pyfakefs
|
|||
pynvml==11.5.0
|
||||
pyre-extensions==0.0.30
|
||||
pytorch-lightning==1.9.2
|
||||
tensorboard==2.11.0
|
||||
tensorboard==2.12.0
|
||||
transformers==4.25.1
|
||||
triton>=2.0.0a2
|
||||
wandb
|
|
@ -13,5 +13,5 @@
|
|||
"betas": [0.9, 0.999],
|
||||
"epsilon": 1e-8,
|
||||
"weight_decay": 0.010,
|
||||
"text_encoder_lr_scale": 1.0
|
||||
"text_encoder_lr_scale": 0.50
|
||||
}
|
||||
|
|
39
train.py
39
train.py
|
@ -49,6 +49,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|||
from accelerate.utils import set_seed
|
||||
|
||||
import wandb
|
||||
import webbrowser
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from data.data_loader import DataLoaderMultiAspect
|
||||
|
||||
|
@ -120,8 +121,12 @@ def setup_local_logger(args):
|
|||
format="%(asctime)s %(message)s",
|
||||
datefmt="%m/%d/%Y %I:%M:%S %p",
|
||||
)
|
||||
|
||||
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
||||
# Mirror log records to stdout, minus PIL's palette/transparency notice.
console_handler = logging.StreamHandler(sys.stdout)
# A logging filter KEEPS a record when it returns a true value, so the check
# has to be "not in"; with "in", every record except the PIL notice was
# dropped from the console.
console_handler.addFilter(lambda msg: "Palette images with Transparency expressed in bytes" not in msg.getMessage())
logging.getLogger().addHandler(console_handler)
import warnings
# `message` is a regex matched against the start of the warning text itself.
# That text does not begin with the category name, so a "UserWarning: " prefix
# would prevent the filter from ever matching.
warnings.filterwarnings("ignore", message="Palette images with Transparency expressed in bytes should be converted to RGBA images")
|
||||
#from PIL import Image
|
||||
|
||||
return datetimestamp
|
||||
|
||||
|
@ -138,7 +143,7 @@ def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
|||
"""
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def load_optimizer(optimizer, path: str):
|
||||
def load_optimizer(optimizer: torch.optim.Optimizer, path: str):
|
||||
"""
|
||||
Loads the optimizer state
|
||||
"""
|
||||
|
@ -491,17 +496,24 @@ def main(args):
|
|||
with open(os.path.join(os.curdir, optimizer_config_path), "r") as f:
|
||||
optimizer_config = json.load(f)
|
||||
|
||||
if args.wandb is not None and args.wandb:
|
||||
wandb.init(project=args.project_name,
|
||||
sync_tensorboard=True,
|
||||
dir=args.logdir,
|
||||
config=args,
|
||||
name=args.run_name,
|
||||
)
|
||||
if args.wandb:
|
||||
wandb.tensorboard.patch(root_logdir=log_folder, pytorch=False, tensorboard_x=False, save=False)
|
||||
wandb_run = wandb.init(
|
||||
project=args.project_name,
|
||||
config={"main_cfg": vars(args), "optimizer_cfg": optimizer_config},
|
||||
name=args.run_name,
|
||||
#sync_tensorboard=True, # broken?
|
||||
#dir=log_folder, # only for save, just duplicates the TB log to /{log_folder}/wandb ...
|
||||
)
|
||||
try:
|
||||
if webbrowser.get():
|
||||
webbrowser.open(wandb_run.url, new=2)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
log_writer = SummaryWriter(log_dir=log_folder,
|
||||
flush_secs=10,
|
||||
comment=args.run_name if args.run_name is not None else "EveryDream2FineTunes",
|
||||
flush_secs=20,
|
||||
comment=args.run_name if args.run_name is not None else log_time,
|
||||
)
|
||||
|
||||
betas = [0.9, 0.999]
|
||||
|
@ -929,8 +941,7 @@ def main(args):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
supported_resolutions = [256, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024, 1088, 1152]
|
||||
supported_precisions = ['fp16', 'fp32']
|
||||
supported_resolutions = aspects.get_supported_resolutions()
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||
args, argv = argparser.parse_known_args()
|
||||
|
|
|
@ -7,7 +7,8 @@
|
|||
"stabilize_training_loss": "If true, stabilize the train loss curves for `loss/epoch` and `loss/log step` by re-calculating training loss with a fixed random seed, and log the results as `loss/train-stabilized`. This more clearly shows the training progress, but it is not enough alone to tell you if you're overfitting.",
|
||||
"stabilize_split_proportion": "For stabilize_training_loss, the proportion of the train dataset to overlap for stabilizing the train loss graph. Typical values are 0.15-0.2 (15-20% of the total dataset). Higher is more accurate but slower.",
|
||||
"every_n_epochs": "How often to run validation (1=every epoch).",
|
||||
"seed": "The seed to use when running validation and stabilization passes."
|
||||
"seed": "The seed to use when running validation and stabilization passes.",
|
||||
"use_relative_loss": "logs val/loss as negative relative to first pre-train val/loss value"
|
||||
},
|
||||
"validate_training": true,
|
||||
"val_split_mode": "automatic",
|
||||
|
@ -16,5 +17,6 @@
|
|||
"stabilize_training_loss": false,
|
||||
"stabilize_split_proportion": 0.15,
|
||||
"every_n_epochs": 1,
|
||||
"seed": 555
|
||||
"seed": 555,
|
||||
"use_relative_loss": false
|
||||
}
|
|
@ -4,7 +4,7 @@ echo should be in venv here
|
|||
cd .
|
||||
python -m pip install --upgrade pip
|
||||
pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url "https://download.pytorch.org/whl/cu116"
|
||||
pip install transformers==4.25.1
|
||||
pip install transformers==4.27.1
|
||||
pip install diffusers[torch]==0.13.0
|
||||
pip install pynvml==11.4.1
|
||||
pip install bitsandbytes==0.35.0
|
||||
|
@ -13,7 +13,7 @@ pip install ftfy==6.1.1
|
|||
pip install aiohttp==3.8.3
|
||||
pip install tensorboard>=2.11.0
|
||||
pip install protobuf==3.20.1
|
||||
pip install wandb==0.13.6
|
||||
pip install wandb==0.14.0
|
||||
pip install pyre-extensions==0.0.23
|
||||
pip install -U -I --no-deps https://github.com/C43H66N12O12S2/stable-diffusion-webui/releases/download/f/xformers-0.0.14.dev0-cp310-cp310-win_amd64.whl
|
||||
::pip install "xformers-0.0.15.dev0+affe4da.d20221212-cp38-cp38-win_amd64.whl" --force-reinstall
|
||||
|
|
Loading…
Reference in New Issue