diff --git a/Train_Colab.ipynb b/Train_Colab.ipynb index 6b47819..495035b 100644 --- a/Train_Colab.ipynb +++ b/Train_Colab.ipynb @@ -68,7 +68,7 @@ "outputs": [], "source": [ "#@title Optional connect Gdrive\n", - "#@markdown # but strongly recommended\n", + "#@markdown # But strongly recommended\n", "#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n", "\n", "#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n", @@ -82,8 +82,8 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", - "id": "hAuBbtSvGpau" + "id": "hAuBbtSvGpau", + "cellView": "form" }, "outputs": [], "source": [ @@ -94,7 +94,7 @@ "s = getoutput('nvidia-smi')\n", "!pip install -q torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url \"https://download.pytorch.org/whl/cu117\"\n", "!pip install -q transformers==4.25.1\n", - "!pip install -q diffusers[torch]==0.10.2\n", + "!pip install -q diffusers[torch]==0.13.0\n", "!pip install -q pynvml==11.4.1\n", "!pip install -q bitsandbytes==0.35.0\n", "!pip install -q ftfy==6.1.1\n", @@ -329,7 +329,12 @@ "#@markdown * Using the same seed each time you train allows for more accurate a/b comparison of models, leave at -1 for random\n", "#@markdown * The seed also effects your training samples, if you want the same seed each sample you will need to change it from -1\n", "Training_Seed = -1 #@param{type:\"integer\"}\n", - "\n", + "#@markdown * use this option to configure a sample_prompts.json\n", + "#@markdown * check out /content/EveryDream2trainer/doc/logging.md. for more details\n", + "Advance_Samples = False #@param{type:\"boolean\"}\n", + "Sample_File = \"sample_prompts.txt\"\n", + "if Advance_Samples:\n", + " Sample_File = \"sample_prompts.json\"\n", "#@markdown * Checkpointing Saves Vram to allow larger batch sizes minor slow down on a single batch size but will can allow room for a higher traning resolution (suggested on Colab Free tier, turn off for A100)\n", "Gradient_checkpointing = True #@param{type:\"boolean\"}\n", "Disable_Xformers = False #@param{type:\"boolean\"}\n", @@ -405,7 +410,7 @@ " --max_epochs $Max_Epochs \\\n", " --project_name \"$Project_Name\" \\\n", " --resolution $Resolution \\\n", - " --sample_prompts \"sample_prompts.txt\" \\\n", + " --sample_prompts \"$Sample_File\" \\\n", " --sample_steps $Steps_between_samples \\\n", " --save_every_n_epoch $Save_every_N_epoch \\\n", " --seed $Training_Seed \\\n", @@ -501,4 +506,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/data/every_dream_validation.py b/data/every_dream_validation.py index 4f2c083..2f92275 100644 --- a/data/every_dream_validation.py +++ b/data/every_dream_validation.py @@ -23,7 +23,7 @@ from utils.isolate_rng import isolate_rng def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \ -> tuple[list[ImageTrainItem], list[ImageTrainItem]]: - split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size + split_item_count = math.ceil(split_proportion * len(items) / batch_size) * batch_size # sort first, then shuffle, to ensure determinate outcome for the current random state items_copy = list(sorted(items, key=lambda i: i.pathname)) random.shuffle(items_copy) diff --git a/train.py b/train.py index 142e7e3..29553c4 100644 --- a/train.py +++ b/train.py @@ -711,6 +711,10 @@ def main(args): return model_pred, target + # Pre-train validation to establish a starting point on the loss graph + if validator: + validator.do_validation_if_appropriate(epoch=0, global_step=0, + get_model_prediction_and_target_callable=get_model_prediction_and_target) try: # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py @@ -849,7 +853,7 @@ def main(args): log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step) if validator: - validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target) + validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target) gc.collect() # end of epoch diff --git a/utils/split_dataset.py b/utils/split_dataset.py index 93a1064..c16f235 100644 --- a/utils/split_dataset.py +++ b/utils/split_dataset.py @@ -8,16 +8,21 @@ from typing import Optional from tqdm.auto import tqdm IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif'] - +CAPTION_EXTENSIONS = ['.txt', '.caption', '.yaml', '.yml'] def gather_captioned_images(root_dir: str) -> list[tuple[str,Optional[str]]]: for directory, _, filenames in os.walk(root_dir): image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS] for image_filename in image_filenames: - caption_filename = os.path.splitext(image_filename)[0] + '.txt' - image_path = os.path.join(directory+image_filename) - caption_path = os.path.join(directory+caption_filename) - yield (image_path, caption_path if os.path.exists(caption_path) else None) + image_path = os.path.join(directory, image_filename) + image_path_without_extension = os.path.splitext(image_path)[0] + caption_path = None + for caption_extension in CAPTION_EXTENSIONS: + possible_caption_path = image_path_without_extension + caption_extension + if os.path.exists(possible_caption_path): + caption_path = possible_caption_path + break + yield image_path, caption_path def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str): @@ -39,13 +44,13 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('source_root', type=str, help='Source root folder containing images') - parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset") - parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset") + parser.add_argument('--train_output_folder', type=str, required=False, help="Output folder for the 'train' dataset. If omitted, do not save the train split.") + parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset.") parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)") parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling') args = parser.parse_args() - images = gather_captioned_images(args.source_root) + images = list(gather_captioned_images(args.source_root)) print(f"Found {len(images)} captioned images in {args.source_root}") val_split_count = math.ceil(len(images) * args.split_proportion) if val_split_count == 0: @@ -59,7 +64,9 @@ if __name__ == '__main__': print(f"Copying 'val' set to {args.val_output_folder}...") for v in tqdm(val_split): copy_captioned_image(v, args.source_root, args.val_output_folder) - print(f"Copying 'train' set to {args.train_output_folder}...") - for v in tqdm(train_split): - copy_captioned_image(v, args.source_root, args.train_output_folder) + + if args.train_output_folder is not None: + print(f"Copying 'train' set to {args.train_output_folder}...") + for v in tqdm(train_split): + copy_captioned_image(v, args.source_root, args.train_output_folder) print("Done.") \ No newline at end of file