Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer

commit a2479cfe1f
@@ -68,7 +68,7 @@
     "outputs": [],
     "source": [
      "#@title Optional connect Gdrive\n",
-     "#@markdown # but strongly recommended\n",
+     "#@markdown # But strongly recommended\n",
      "#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
      "\n",
      "#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
@@ -82,8 +82,8 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "cellView": "form",
-    "id": "hAuBbtSvGpau"
+    "id": "hAuBbtSvGpau",
+    "cellView": "form"
    },
    "outputs": [],
    "source": [
@@ -94,7 +94,7 @@
     "s = getoutput('nvidia-smi')\n",
     "!pip install -q torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url \"https://download.pytorch.org/whl/cu117\"\n",
     "!pip install -q transformers==4.25.1\n",
-    "!pip install -q diffusers[torch]==0.10.2\n",
+    "!pip install -q diffusers[torch]==0.13.0\n",
     "!pip install -q pynvml==11.4.1\n",
     "!pip install -q bitsandbytes==0.35.0\n",
     "!pip install -q ftfy==6.1.1\n",
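If the bumped diffusers pin matters to your setup, a quick sanity check can confirm the environment matches what the cell above installs; a minimal sketch, assuming only the standard `__version__` attributes these packages expose:

    import torch, transformers, diffusers

    print(torch.__version__)         # expect 1.13.1+cu117 (unchanged by this commit)
    print(transformers.__version__)  # expect 4.25.1 (unchanged)
    print(diffusers.__version__)     # expect 0.13.0 after the bump above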
@@ -329,7 +329,12 @@
     "#@markdown * Using the same seed each time you train allows for more accurate a/b comparison of models, leave at -1 for random\n",
     "#@markdown * The seed also affects your training samples; if you want the same seed each sample you will need to change it from -1\n",
     "Training_Seed = -1 #@param{type:\"integer\"}\n",
     "\n",
+    "#@markdown * Use this option to configure a sample_prompts.json\n",
+    "#@markdown * Check out /content/EveryDream2trainer/doc/logging.md for more details\n",
+    "Advance_Samples = False #@param{type:\"boolean\"}\n",
+    "Sample_File = \"sample_prompts.txt\"\n",
+    "if Advance_Samples:\n",
+    "  Sample_File = \"sample_prompts.json\"\n",
     "#@markdown * Checkpointing saves VRAM to allow larger batch sizes; minor slowdown on a single batch size, but can allow room for a higher training resolution (suggested on Colab free tier, turn off for A100)\n",
     "Gradient_checkpointing = True #@param{type:\"boolean\"}\n",
     "Disable_Xformers = False #@param{type:\"boolean\"}\n",
@@ -405,7 +410,7 @@
     "  --max_epochs $Max_Epochs \\\n",
     "  --project_name \"$Project_Name\" \\\n",
     "  --resolution $Resolution \\\n",
-    "  --sample_prompts \"sample_prompts.txt\" \\\n",
+    "  --sample_prompts \"$Sample_File\" \\\n",
     "  --sample_steps $Steps_between_samples \\\n",
     "  --save_every_n_epoch $Save_every_N_epoch \\\n",
     "  --seed $Training_Seed \\\n",
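For orientation, a hedged sketch of the command these form fields feed into once `$Sample_File` is substituted. The flag names come from the hunk above; the entry point and all values are illustrative assumptions, not taken from this diff:

    python train.py \
      --max_epochs 100 \
      --project_name "my_project" \
      --resolution 512 \
      --sample_prompts "sample_prompts.json" \
      --sample_steps 300 \
      --save_every_n_epoch 1 \
      --seed -1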
@@ -501,4 +506,4 @@
   },
   "nbformat": 4,
   "nbformat_minor": 0
 }
@@ -23,7 +23,7 @@ from utils.isolate_rng import isolate_rng
 
 def get_random_split(items: list[ImageTrainItem], split_proportion: float, batch_size: int) \
         -> tuple[list[ImageTrainItem], list[ImageTrainItem]]:
-    split_item_count = math.ceil(split_proportion * len(items) // batch_size) * batch_size
+    split_item_count = math.ceil(split_proportion * len(items) / batch_size) * batch_size
     # sort first, then shuffle, to ensure determinate outcome for the current random state
     items_copy = list(sorted(items, key=lambda i: i.pathname))
     random.shuffle(items_copy)
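This one-character fix changes the rounding semantics of the validation split size. A minimal worked example (standalone Python, illustrative numbers only) of how the old floor division under-sizes, and can even zero out, the val split:

    import math

    batch_size = 8

    # 100 items at split_proportion 0.15:
    old = math.ceil(0.15 * 100 // batch_size) * batch_size  # ceil(15 // 8) * 8 = ceil(1.0) * 8 = 8
    new = math.ceil(0.15 * 100 / batch_size) * batch_size   # ceil(15 / 8) * 8 = ceil(1.875) * 8 = 16

    # 20 items at split_proportion 0.15: the old form produced an empty val split
    old_small = math.ceil(0.15 * 20 // batch_size) * batch_size  # ceil(3 // 8) * 8 = 0
    new_small = math.ceil(0.15 * 20 / batch_size) * batch_size   # ceil(3 / 8) * 8 = 8

    print(old, new, old_small, new_small)  # 8 16 0 8

With true division, the count still lands on a whole number of batches but now rounds up, so even small datasets get at least one validation batch.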
train.py
@@ -711,6 +711,10 @@ def main(args):
 
         return model_pred, target
 
+    # Pre-train validation to establish a starting point on the loss graph
+    if validator:
+        validator.do_validation_if_appropriate(epoch=0, global_step=0,
+                                               get_model_prediction_and_target_callable=get_model_prediction_and_target)
+
     try:
         # # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
@@ -849,7 +853,7 @@ def main(args):
             log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
 
         if validator:
-            validator.do_validation_if_appropriate(epoch, global_step, get_model_prediction_and_target)
+            validator.do_validation_if_appropriate(epoch+1, global_step, get_model_prediction_and_target)
 
         gc.collect()
         # end of epoch
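Together with the pre-train validation added at line 711 above, passing `epoch+1` makes every validation point read as "epochs completed": the baseline owns epoch 0 and end-of-epoch runs report 1 through N. A minimal sketch of the resulting numbering, using a hypothetical stub rather than the project's validator:

    def log_validation_point(epoch: int) -> None:
        # stand-in for validator.do_validation_if_appropriate(...)
        print(f"validation at epoch={epoch}")

    max_epochs = 3
    log_validation_point(epoch=0)              # pre-train baseline (hunk at line 711)
    for epoch in range(max_epochs):            # the training loop is 0-indexed
        # ... train one epoch ...
        log_validation_point(epoch=epoch + 1)  # reports 1..max_epochs; no collision with the baseline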
@@ -8,16 +8,21 @@ from typing import Optional
 from tqdm.auto import tqdm
 
 IMAGE_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.bmp', '.webp', '.jfif']
+CAPTION_EXTENSIONS = ['.txt', '.caption', '.yaml', '.yml']
 
 def gather_captioned_images(root_dir: str) -> list[tuple[str,Optional[str]]]:
     for directory, _, filenames in os.walk(root_dir):
         image_filenames = [f for f in filenames if os.path.splitext(f)[1].lower() in IMAGE_EXTENSIONS]
         for image_filename in image_filenames:
-            caption_filename = os.path.splitext(image_filename)[0] + '.txt'
-            image_path = os.path.join(directory+image_filename)
-            caption_path = os.path.join(directory+caption_filename)
-            yield (image_path, caption_path if os.path.exists(caption_path) else None)
+            image_path = os.path.join(directory, image_filename)
+            image_path_without_extension = os.path.splitext(image_path)[0]
+            caption_path = None
+            for caption_extension in CAPTION_EXTENSIONS:
+                possible_caption_path = image_path_without_extension + caption_extension
+                if os.path.exists(possible_caption_path):
+                    caption_path = possible_caption_path
+                    break
+            yield image_path, caption_path
 
 
 def copy_captioned_image(image_caption_pair: tuple[str, Optional[str]], source_root: str, target_root: str):
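A hedged usage sketch of the rewritten generator (the dataset path is hypothetical): it now pairs each image with the first caption file found among CAPTION_EXTENSIONS rather than only `.txt`, and yields pairs lazily:

    # Materialize the lazy generator, then count images still lacking any caption file.
    pairs = list(gather_captioned_images('/data/my_dataset'))  # path is illustrative
    uncaptioned = [image for image, caption in pairs if caption is None]
    print(f"{len(pairs)} images found, {len(uncaptioned)} without captions")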
@@ -39,13 +44,13 @@ if __name__ == '__main__':
 
     parser = argparse.ArgumentParser()
     parser.add_argument('source_root', type=str, help='Source root folder containing images')
-    parser.add_argument('--train_output_folder', type=str, required=True, help="Output folder for the 'train' dataset")
-    parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset")
+    parser.add_argument('--train_output_folder', type=str, required=False, help="Output folder for the 'train' dataset. If omitted, do not save the train split.")
+    parser.add_argument('--val_output_folder', type=str, required=True, help="Output folder for the 'val' dataset.")
     parser.add_argument('--split_proportion', type=float, required=True, help="Proportion of images to use for 'val' (a number between 0 and 1)")
     parser.add_argument('--seed', type=int, required=False, default=555, help='Random seed for shuffling')
     args = parser.parse_args()
 
-    images = gather_captioned_images(args.source_root)
+    images = list(gather_captioned_images(args.source_root))
     print(f"Found {len(images)} captioned images in {args.source_root}")
     val_split_count = math.ceil(len(images) * args.split_proportion)
     if val_split_count == 0:
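The new `list(...)` wrapper matters because `gather_captioned_images` now uses `yield` and therefore returns a generator; `len(images)` on the next line would otherwise raise. A standalone illustration:

    def gen():
        yield 1
        yield 2

    g = gen()
    # len(g)           # TypeError: object of type 'generator' has no len()
    items = list(g)    # materialize once; len(), indexing, and shuffling now all work
    print(len(items))  # 2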
@@ -59,7 +64,9 @@ if __name__ == '__main__':
     print(f"Copying 'val' set to {args.val_output_folder}...")
     for v in tqdm(val_split):
         copy_captioned_image(v, args.source_root, args.val_output_folder)
-    print(f"Copying 'train' set to {args.train_output_folder}...")
-    for v in tqdm(train_split):
-        copy_captioned_image(v, args.source_root, args.train_output_folder)
+
+    if args.train_output_folder is not None:
+        print(f"Copying 'train' set to {args.train_output_folder}...")
+        for v in tqdm(train_split):
+            copy_captioned_image(v, args.source_root, args.train_output_folder)
     print("Done.")
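With `--train_output_folder` now optional, the script can produce only the val split. A hedged invocation sketch; the script's filename is not visible in this diff, so `split_script.py` is a placeholder, and the paths are illustrative:

    python split_script.py /data/my_dataset \
      --val_output_folder /data/val \
      --split_proportion 0.15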