Merge branch 'main' into refactor-data-resolution

This commit is contained in:
Joel Holdbrooks 2023-01-23 08:43:23 -08:00
commit b6c7299baf
6 changed files with 100 additions and 30 deletions

View File

@ -12,12 +12,17 @@
] ]
}, },
{ {
"attachments": {},
"cell_type": "markdown", "cell_type": "markdown",
"metadata": { "metadata": {
"id": "blaLMSbkPHhG" "id": "blaLMSbkPHhG"
}, },
"source": [ "source": [
"EveryDream2 Colab Edition" "# EveryDream2 Colab Edition\n",
"\n",
"Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
"\n",
"And join the discord: https://discord.gg/uheqxU6sXN"
] ]
}, },
{ {
@ -225,7 +230,7 @@
"#@title Resume from a diffusers model saved to your Gdrive\n", "#@title Resume from a diffusers model saved to your Gdrive\n",
"#@markdown * if you have preveiously saved diffusers on your drive you can slect them here\n", "#@markdown * if you have preveiously saved diffusers on your drive you can slect them here\n",
"#@markdown ex. */content/drive/MyDrive/everydreamlogs/myproject_202208/ckpts/interrupted-gs023*\n", "#@markdown ex. */content/drive/MyDrive/everydreamlogs/myproject_202208/ckpts/interrupted-gs023*\n",
"Resume_Model = \"\" #@param{type:\"string\"} \n", "Resume_Model = \"/content/drive/MyDrive/everydreamlogs/ckpt/SD15\" #@param{type:\"string\"} \n",
"save_name= Resume_Model" "save_name= Resume_Model"
] ]
}, },
@ -268,7 +273,8 @@
"#@markdown * Name your project so you can find it in your logs\n", "#@markdown * Name your project so you can find it in your logs\n",
"Project_Name = \"my_project\" #@param{type: 'string'}\n", "Project_Name = \"my_project\" #@param{type: 'string'}\n",
"\n", "\n",
"#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data. See Advanced Tweaking for more info. Once you have started, the learning rate is a good first knob to turn as you move into more advanced tweaking.\n", "#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data and not wreck the model. \n",
"#@markdown Leave this default unless you are very comfortable with training and know what you are doing.\n",
"\n", "\n",
"Learning_Rate = 1e-6 #@param{type: 'number'}\n", "Learning_Rate = 1e-6 #@param{type: 'number'}\n",
"\n", "\n",
@ -294,6 +300,8 @@
"#@markdown * Remember more gradient accumulation (or batch size) doesn't automatically mean better\n", "#@markdown * Remember more gradient accumulation (or batch size) doesn't automatically mean better\n",
"\n", "\n",
"Gradient_steps = 1 #@param{type:\"slider\", min:1, max:10, step:1}\n", "Gradient_steps = 1 #@param{type:\"slider\", min:1, max:10, step:1}\n",
"\n",
"#@markdown * Location on your Gdrive where your training images are.\n",
"Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n", "Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
"dataset = Dataset_Location\n", "dataset = Dataset_Location\n",
"model = save_name\n", "model = save_name\n",
@ -309,8 +317,11 @@
"\n", "\n",
"#@markdown You can set your own sample prompts by adding them, one line at a time, to `/content/EveryDream2trainer/sample_prompts.txt`. If left empty, it will use the captions from your training images.\n", "#@markdown You can set your own sample prompts by adding them, one line at a time, to `/content/EveryDream2trainer/sample_prompts.txt`. If left empty, it will use the captions from your training images.\n",
"\n", "\n",
"#@markdown Use the steps_between_samples to set how often the samples are generated.\n",
"Steps_between_samples = 300 #@param{type:\"integer\"}\n", "Steps_between_samples = 300 #@param{type:\"integer\"}\n",
"\n", "\n",
"#@markdown * That's it! Run the cell!\n",
"\n",
"Drive=\"\"\n", "Drive=\"\"\n",
"if Save_to_Gdrive:\n", "if Save_to_Gdrive:\n",
" Drive = \"--logdir /content/drive/MyDrive/everydreamlogs --save_ckpt_dir /content/drive/MyDrive/everydreamlogs/ckpt\"\n", " Drive = \"--logdir /content/drive/MyDrive/everydreamlogs --save_ckpt_dir /content/drive/MyDrive/everydreamlogs/ckpt\"\n",
@ -347,7 +358,7 @@
" $DX \\\n", " $DX \\\n",
" --amp \\\n", " --amp \\\n",
" --batch_size $Batch_Size \\\n", " --batch_size $Batch_Size \\\n",
" --grad_accum 2 \\\n", " --grad_accum $Gradient_steps \\\n",
" --cond_dropout 0.00 \\\n", " --cond_dropout 0.00 \\\n",
" --data_root \"$dataset\" \\\n", " --data_root \"$dataset\" \\\n",
" --flip_p 0.00 \\\n", " --flip_p 0.00 \\\n",
@ -376,8 +387,21 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"#@title Alternate startup script\n", "#@title Alternate startup script\n",
"#@markdown Edit train.json to setup your paramaters\n", "#@markdown * Edit train.json to setup your paramaters\n",
"!python train.py --config train.json" "#@markdown * Edit chain0.json to make use of chaining\n",
"#@markdown * make sure to check each confguration you will need 1 Json per chain length 3 are provided\n",
"\n",
"\n",
"%cd /content/EveryDream2trainer\n",
"Chain_Length=0 #@param{type:\"integer\"}\n",
"l = Chain_Length \n",
"I=0 #repeat counter\n",
"if l == None or l == 0:\n",
" l=1\n",
"while l > 0:\n",
" !python train_colab.py --config chain{I}.json\n",
" l -= 1\n",
" I =+ 1"
] ]
} }
], ],

View File

@ -17,6 +17,7 @@ import bisect
import math import math
import os import os
import logging import logging
import copy
import random import random
from data.image_train_item import ImageTrainItem from data.image_train_item import ImageTrainItem
@ -46,7 +47,6 @@ class DataLoaderMultiAspect():
self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False) self.aspects = aspects.get_aspect_buckets(resolution=resolution, square_only=False)
logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}") logging.info(f"* DLMA resolution {resolution}, buckets: {self.aspects}")
logging.info(" Preloading images...")
self.__prepare_train_data() self.__prepare_train_data()
(self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings() (self.rating_overall_sum, self.ratings_summed) = self.__sort_and_precalc_image_ratings()
@ -55,32 +55,37 @@ class DataLoaderMultiAspect():
""" """
Deals with multiply.txt whole and fractional numbers Deals with multiply.txt whole and fractional numbers
""" """
prepared_train_data_local = self.prepared_train_data.copy() #print(f"Picking multiplied set from {len(self.prepared_train_data)}")
data_copy = copy.deepcopy(self.prepared_train_data) # deep copy to avoid modifying original multiplier property
epoch_size = len(self.prepared_train_data) epoch_size = len(self.prepared_train_data)
picked_images = [] picked_images = []
# add by whole number part first and decrement multiplier in copy # add by whole number part first and decrement multiplier in copy
for iti in prepared_train_data_local: for iti in data_copy:
#print(f"check for whole number {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}")
while iti.multiplier >= 1.0: while iti.multiplier >= 1.0:
picked_images.append(iti) picked_images.append(iti)
iti.multiplier -= 1 #print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {iti.multiplier}, , datalen: {len(picked_images)}")
if iti.multiplier == 0: iti.multiplier -= 1.0
prepared_train_data_local.remove(iti)
remaining = epoch_size - len(picked_images) remaining = epoch_size - len(picked_images)
assert remaining >= 0, "Something went wrong with the multiplier calculation" assert remaining >= 0, "Something went wrong with the multiplier calculation"
#print(f"Remaining to fill epoch after whole number adds: {remaining}")
#print(f"Remaining in data copy: {len(data_copy)}")
# add by renaming fractional numbers by random chance # add by renaming fractional numbers by random chance
while remaining > 0: while remaining > 0:
for iti in prepared_train_data_local: for iti in data_copy:
if randomizer.uniform(0.0, 1) < iti.multiplier: if randomizer.uniform(0.0, 1.0) < iti.multiplier:
#print(f"Adding {iti.multiplier}: {iti.pathname}, remaining {remaining}, datalen: {len(data_copy)}")
picked_images.append(iti) picked_images.append(iti)
remaining -= 1 remaining -= 1
prepared_train_data_local.remove(iti) data_copy.remove(iti)
if remaining <= 0: if remaining <= 0:
break break
del data_copy
return picked_images return picked_images
def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0): def get_shuffled_image_buckets(self, dropout_fraction: float = 1.0):
@ -150,9 +155,17 @@ class DataLoaderMultiAspect():
""" """
Create ImageTrainItem objects with metadata for hydration later Create ImageTrainItem objects with metadata for hydration later
""" """
if not self.has_scanned: if not self.has_scanned:
self.has_scanned = True self.has_scanned = True
self.prepared_train_data, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p)
logging.info(" Preloading images...")
items, events = resolver.resolve(self.data_root, self.aspects, flip_p=flip_p, seed=self.seed)
image_paths = set(map(lambda item: item.pathname, items))
print (f" * DLMA: {len(items)} images loaded from {len(image_paths)} files")
random.Random(self.seed).shuffle(self.prepared_train_data) random.Random(self.seed).shuffle(self.prepared_train_data)
self.__report_undersized_images(events) self.__report_undersized_images(events)

View File

@ -71,7 +71,6 @@ class EveryDreamBatch(Dataset):
self.rated_dataset = rated_dataset self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target self.rated_dataset_dropout_target = rated_dataset_dropout_target
if seed == -1: if seed == -1:
seed = random.randint(0, 99999) seed = random.randint(0, 99999)

View File

@ -105,6 +105,10 @@ class ImageCaption:
caption += ", " caption += ", "
caption += tag caption += tag
if caption:
caption += ", "
caption += tag
return caption return caption
@staticmethod @staticmethod

View File

@ -1,6 +1,7 @@
import json import json
import logging import logging
import os import os
import random
import typing import typing
import zipfile import zipfile
@ -25,7 +26,7 @@ class UndersizedImageEvent(Event):
self.target_size = target_size self.target_size = target_size
class DataResolver: class DataResolver:
def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0): def __init__(self, aspects: list[typing.Tuple[int, int]], flip_p=0.0, seed=555):
self.aspects = aspects self.aspects = aspects
self.flip_p = flip_p self.flip_p = flip_p
self.events = [] self.events = []
@ -141,6 +142,7 @@ class DirectoryResolver(DataResolver):
items = [] items = []
multipliers = {} multipliers = {}
skip_folders = [] skip_folders = []
randomizer = random.Random(self.seed)
for pathname in tqdm.tqdm(image_paths): for pathname in tqdm.tqdm(image_paths):
current_dir = os.path.dirname(pathname) current_dir = os.path.dirname(pathname)
@ -164,6 +166,16 @@ class DirectoryResolver(DataResolver):
caption = ImageCaption.resolve(pathname) caption = ImageCaption.resolve(pathname)
item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir]) item = self.image_train_item(pathname, caption, multiplier=multipliers[current_dir])
cur_file_multiplier = multipliers[current_dir]
while cur_file_multiplier >= 1.0:
items.append(item)
cur_file_multiplier -= 1
if cur_file_multiplier > 0:
if randomizer.random() < cur_file_multiplier:
items.append(item)
if item: if item:
items.append(item) items.append(item)
return items return items
@ -210,17 +222,17 @@ def strategy(data_root: str):
raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.") raise ValueError(f"data_root '{data_root}' is not a valid directory or JSON file.")
def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]: def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0, seed) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
""" """
:param data_root: Directory or JSON file. :param data_root: Directory or JSON file.
:param aspects: The list of aspect ratios to use :param aspects: The list of aspect ratios to use
:param flip_p: The probability of flipping the image :param flip_p: The probability of flipping the image
""" """
if os.path.isfile(path) and path.endswith('.json'): if os.path.isfile(path) and path.endswith('.json'):
resolver = JSONResolver(aspects, flip_p) resolver = JSONResolver(aspects, flip_p, seed)
if os.path.isdir(path): if os.path.isdir(path):
resolver = DirectoryResolver(aspects, flip_p) resolver = DirectoryResolver(aspects, flip_p, seed)
if not resolver: if not resolver:
raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.") raise ValueError(f"data_root '{path}' is not a valid directory or JSON file.")
@ -229,7 +241,7 @@ def resolve_root(path: str, aspects: list[float], flip_p: float = 0.0) -> typing
events = resolver.events events = resolver.events
return items, events return items, events
def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0) -> typing.Tuple[list[ImageTrainItem], list[Event]]: def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=0.0, seed=555) -> typing.Tuple[list[ImageTrainItem], list[Event]]:
""" """
Resolve the training data from the value. Resolve the training data from the value.
:param value: The value to resolve, either a dict or a string. :param value: The value to resolve, either a dict or a string.
@ -244,12 +256,12 @@ def resolve(value: typing.Union[dict, str], aspects: list[float], flip_p: float=
match resolver: match resolver:
case 'directory' | 'json': case 'directory' | 'json':
path = value.get('path', None) path = value.get('path', None)
return resolve_root(path, aspects, flip_p) return resolve_root(path, aspects, flip_p, seed)
case 'multi': case 'multi':
resolved_items = [] resolved_items = []
resolved_events = [] resolved_events = []
for resolver in value.get('resolvers', []): for resolver in value.get('resolvers', []):
items, events = resolve(resolver, aspects, flip_p) items, events = resolve(resolver, aspects, flip_p, seed)
resolved_items.extend(items) resolved_items.extend(items)
resolved_events.extend(events) resolved_events.extend(events)
return resolved_items, resolved_events return resolved_items, resolved_events

View File

@ -1,5 +1,5 @@
""" """
Copyright [2022] Victor C Hall Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License; Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License. You may not use this code except in compliance with the License.
@ -343,7 +343,7 @@ def main(args):
logging.info(f" * Saving SD model to {sd_ckpt_full}") logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half) converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name: if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml" yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}") logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path) shutil.copyfile(yaml_name, yaml_save_path)
@ -367,7 +367,6 @@ def main(args):
safety_checker=None, # save vram safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none of no safety checker feature_extractor=None, # must be none of no safety checker
disable_tqdm=True,
) )
return pipe return pipe
@ -410,6 +409,8 @@ def main(args):
generates samples at different cfg scales and saves them to disk generates samples at different cfg scales and saves them to disk
""" """
logging.info(f"Generating samples gs:{gs}, for {prompts}") logging.info(f"Generating samples gs:{gs}, for {prompts}")
pipe.set_progress_bar_config(disable=True)
seed = args.seed if args.seed != -1 else random.randint(0, 2**30) seed = args.seed if args.seed != -1 else random.randint(0, 2**30)
gen = torch.Generator(device=device).manual_seed(seed) gen = torch.Generator(device=device).manual_seed(seed)
@ -588,7 +589,7 @@ def main(args):
""" """
print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}") print(f" {Fore.LIGHTGREEN_EX}** Welcome to EveryDream trainer 2.0!**{Style.RESET_ALL}")
print(f" (C) 2022 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html") print(f" (C) 2022-2023 Victor C Hall This program is licensed under AGPL 3.0 https://www.gnu.org/licenses/agpl-3.0.en.html")
print() print()
print("** Trainer Starting **") print("** Trainer Starting **")
@ -691,6 +692,24 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
try: try:
# # dummy batch to pin memory to avoid fragmentation in torch, uses square aspect which is maximum bytes size per aspects.py
# pixel_values = torch.randn_like(torch.zeros([args.batch_size, 3, args.resolution, args.resolution]))
# pixel_values = pixel_values.to(unet.device)
# with autocast(enabled=args.amp):
# latents = vae.encode(pixel_values, return_dict=False)
# latents = latents[0].sample() * 0.18215
# noise = torch.randn_like(latents)
# bsz = latents.shape[0]
# timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
# timesteps = timesteps.long()
# noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# cuda_caption = torch.linspace(100,177, steps=77, dtype=int).to(text_encoder.device)
# encoder_hidden_states = text_encoder(cuda_caption, output_hidden_states=True).last_hidden_state
# with autocast(enabled=args.amp):
# model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# # discard the grads, just want to pin memory
# optimizer.zero_grad(set_to_none=True)
for epoch in range(args.max_epochs): for epoch in range(args.max_epochs):
loss_epoch = [] loss_epoch = []
epoch_start_time = time.time() epoch_start_time = time.time()
@ -801,7 +820,6 @@ def main(args):
if (global_step + 1) % args.sample_steps == 0: if (global_step + 1) % args.sample_steps == 0:
pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae) pipe = __create_inference_pipe(unet=unet, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=sample_scheduler, vae=vae)
pipe = pipe.to(device) pipe = pipe.to(device)
#pipe.set_progress_bar_config(progress_bar=False)
with torch.no_grad(): with torch.no_grad():
if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1: if sample_prompts is not None and len(sample_prompts) > 0 and len(sample_prompts[0]) > 1: