This commit is contained in:
Victor Hall 2023-06-10 16:27:11 -04:00
commit 0e186a2a12
16 changed files with 500 additions and 273 deletions

2
.gitignore vendored
View File

@ -14,3 +14,5 @@
.ssh_config
*inference*.yaml
.idea
/.cache
/models

View File

@ -16,35 +16,79 @@
"id": "blaLMSbkPHhG"
},
"source": [
"# EveryDream2 Colab Edition\n",
"<p align=\"center\">\n",
" <img src=\"https://github.com/victorchall/EveryDream2trainer/blob/562c4341137d1d9f5bf525e6c56fb4b1eefa2b57/doc/ed_logo_comp.jpg?raw=true\" width=\"600\" height=\"300\">\n",
"</p>\n",
"\n",
"Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
"<br>\n",
"\n",
"And join the discord: https://discord.gg/uheqxU6sXN"
"---\n",
"\n",
"<div align=\"center\">\n",
" <font size=\"6\" color=\"yellow\">Colab Edition</font>\n",
"</div>\n",
"\n",
"---\n",
"\n",
"<br>\n",
"\n",
"Check out the **EveryDream2trainer** documentation and runpod/vastai and local setups here: \n",
"\n",
"[📚 **Documentation**](https://github.com/victorchall/EveryDream2trainer#docs)\n",
"\n",
"And join our vibrant community on Discord:\n",
"\n",
"[💬 **Join the Discord**](https://discord.gg/uheqxU6sXN)\n",
"\n",
"If you find this tool useful, please consider subscribing to the project on Patreon or making a one-time donation on Ko-fi. Your donations keep this project alive as a free open-source tool with ongoing enhancements.\n",
"\n",
"<br>\n",
"\n",
"<p align=\"center\">\n",
" <a href=\"https://www.patreon.com/everydream\">\n",
" <img src=\"https://github.com/victorchall/EveryDream2trainer/raw/main/.github/patreon-medium-button.png?raw=true\" width=\"200\" height=\"50\">\n",
" </a>\n",
"</p>\n",
"\n",
"<br>\n",
"\n",
"<p align=\"center\">\n",
" <a href=\"https://ko-fi.com/everydream\">\n",
" <img src=\"https://github.com/victorchall/EveryDream2trainer/raw/main/.github/kofibutton_sm.png?raw=true\" width=\"75\" height=\"75\">\n",
" </a>\n",
"</p>\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "hAuBbtSvGpau",
"cellView": "form"
"cellView": "form",
"id": "hAuBbtSvGpau"
},
"outputs": [],
"source": [
"#@markdown # Setup and Install Dependencies\n",
"from IPython.display import clear_output\n",
"from IPython.display import clear_output, display, HTML\n",
"import subprocess\n",
"from tqdm.auto import tqdm\n",
"import time\n",
"import os \n",
"from tqdm.auto import tqdm\n",
"import PIL\n",
"\n",
"# Defining function for colored text\n",
"def colored(r, g, b, text):\n",
" return f\"\\033[38;2;{r};{g};{b}m{text} \\033[38;2;255;255;255m\"\n",
"\n",
"#@markdown Optional connect Gdrive But strongly recommended\n",
"#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
"#@markdown This will let you put all your training data and checkpoints directly on your drive. \n",
"#@markdown Much faster/easier to continue later, less setup time.\n",
"\n",
"#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
"Mount_to_Gdrive = True #@param{type:\"boolean\"} \n",
"\n",
"# Clone the git repository\n",
"print(colored(0, 255, 0, 'Cloning git repository...'))\n",
"!git clone https://github.com/victorchall/EveryDream2trainer.git\n",
"\n",
"if Mount_to_Gdrive:\n",
@ -55,35 +99,37 @@
"\n",
"%cd /content/EveryDream2trainer\n",
"\n",
"# Download Arial font\n",
"print(colored(0, 255, 0, 'Downloading Arial font...'))\n",
"!wget -O arial.ttf https://raw.githubusercontent.com/matomo-org/travis-scripts/master/fonts/Arial.ttf\n",
"\n",
"!cp /content/arial.ttf /usr/share/fonts/truetype/\n",
"\n",
"\n",
"packages = [\n",
" 'transformers==4.27.1',\n",
" 'transformers==4.29.2',\n",
" 'diffusers[torch]==0.14.0',\n",
" 'pynvml==11.4.1',\n",
" 'bitsandbytes==0.37.2',\n",
" 'ftfy==6.1.1',\n",
" 'aiohttp==3.8.4',\n",
" 'compel~=1.1.3',\n",
" 'protobuf==3.20.3',\n",
" 'wandb==0.13.6',\n",
" 'pyre-extensions==0.0.23',\n",
" '--no-deps xformers==0.0.19',\n",
" 'pytorch-lightning==1.9.2',\n",
" 'protobuf==3.20.1',\n",
" 'wandb==0.15.3',\n",
" 'pyre-extensions==0.0.29',\n",
" 'xformers==0.0.20',\n",
" 'pytorch-lightning==1.6.5',\n",
" 'OmegaConf==2.2.3',\n",
" 'tensorboard>=2.11.0',\n",
" 'tensorrt'\n",
" 'wandb',\n",
" 'colorama',\n",
" 'keyboard',\n",
" 'lion-pytorch'\n",
"]\n",
"\n",
"for package in tqdm(packages, desc='Installing packages', unit='package'):\n",
"print(colored(0, 255, 0, 'Installing packages...'))\n",
"for package in tqdm(packages, desc='Installing packages', unit='package', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}'):\n",
" if isinstance(package, tuple):\n",
" package_name, extra_index_url = package\n",
" cmd = f\"pip install -q {package_name} --extra-index-url {extra_index_url}\"\n",
" cmd = f\"pip install -I -q {package_name} --extra-index-url {extra_index_url}\"\n",
" else:\n",
" cmd = f\"pip install -q {package}\"\n",
" \n",
@ -92,28 +138,62 @@
"clear_output()\n",
"\n",
"\n",
"\n",
"# Execute Python script\n",
"print(colored(0, 255, 0, 'Executing Python script...'))\n",
"!python utils/get_yamls.py\n",
"clear_output()\n",
"## ty Google for cutting out install time by 50%\n",
"print(\"DONE! installing dependencies.\")\n",
"GPU = !nvidia-smi\n",
"print(\"GPU details:\")\n",
"for line in GPU:\n",
" print(line)\n",
"\n",
"print(colored(0, 255, 0, \"DONE! installing dependencies.\"))\n",
"\n",
"# Import pynvml and get GPU details\n",
"import pynvml\n",
"\n",
"pynvml.nvmlInit()\n",
"\n",
"handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
"\n",
"gpu_name = pynvml.nvmlDeviceGetName(handle)\n",
"gpu_memory = pynvml.nvmlDeviceGetMemoryInfo(handle).total / 1024**3\n",
"cuda_version_number = pynvml.nvmlSystemGetCudaDriverVersion_v2()\n",
"cuda_version_major = cuda_version_number // 1000\n",
"cuda_version_minor = (cuda_version_number % 1000) // 10\n",
"cuda_version = f\"{cuda_version_major}.{cuda_version_minor}\"\n",
"\n",
"pynvml.nvmlShutdown()\n",
"\n",
"Python_version = !python --version\n",
"print(\"\\nPython version:\")\n",
"print(Python_version[0])\n",
"import torch\n",
"print(\"\\nPyTorch version:\")\n",
"print(torch.__version__)\n",
"import torchvision\n",
"print(\"\\nTorchvision version:\")\n",
"print(torchvision.__version__)\n",
"import xformers\n",
"print(\"\\nXFormers version:\")\n",
"print(xformers.__version__)\n",
"time.sleep(2)\n"
"\n",
"display(HTML(f\"\"\"\n",
"<table style=\"background-color:transparent;\">\n",
" <tr>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">Python version:</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{Python_version[0]}</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">GPU Name:</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{gpu_name}</span></td>\n",
" </tr>\n",
" <tr>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">PyTorch version:</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{torch.__version__}</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">GPU Memory (GB):</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{gpu_memory:.2f}</span></td>\n",
" </tr>\n",
" <tr>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">Torchvision version:</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{torchvision.__version__}</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">CUDA version:</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{cuda_version}</span></td>\n",
" </tr>\n",
" <tr>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">XFormers version:</span></td>\n",
" <td style=\"background-color:transparent;\"><span style=\"color: #FFFF00;\">{xformers.__version__}</span></td>\n",
" </tr>\n",
"</table>\n",
"\"\"\"))\n",
"\n",
"time.sleep(2)"
]
},
{
@ -127,43 +207,45 @@
"source": [
"#@title Get A Base Model\n",
"#@markdown Choose SD1.5, Waifu Diffusion 1.3, SD2.1, or 2.1(512) from the dropdown, or paste your own URL in the box\n",
"#@markdown * Alternately you can link to a HF repo using NAME/MODEL, this does not save to your Gdrive, if you want to save an hf model use the direct url\n",
"#@markdown * Alternately you can link to an HF repo using NAME/MODEL, this does not save to your Gdrive, if you want to save an HF model, use the direct URL\n",
"\n",
"#@markdown * Link to a set of diffusers on your Gdrive\n",
"\n",
"#@markdown * Paste a url, atm there is no support for .safetensors\n",
"#@markdown * Paste a URL, atm there is no support for .safetensors\n",
"\n",
"from IPython.display import clear_output\n",
"!mkdir input\n",
"%cd /content/EveryDream2trainer\n",
"MODEL_LOCATION = \"sd_v1-5+vae.ckpt\" #@param [\"sd_v1-5+vae.ckpt\", \"hakurei/waifu-diffusion-v1-3\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
"MODEL_LOCATION = \"panopstor/EveryDream\" #@param [\"sd_v1-5+vae.ckpt\", \"hakurei/waifu-diffusion-v1-3\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
"\n",
"if MODEL_LOCATION == \"sd_v1-5+vae.ckpt\":\n",
" MODEL_LOCATION = \"panopstor/EveryDream\"\n",
"Flag = False\n",
"\n",
"If_Ckpt = False\n",
"import os\n",
"\n",
"download_path = \"\"\n",
"\n",
"if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION: #maybe just add a radio button to download this should work for now\n",
" print(\"Downloading \")\n",
"if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION:\n",
" MODEL_URL = MODEL_LOCATION\n",
" print(\"Downloading...\")\n",
" !wget $MODEL_LOCATION\n",
" clear_output()\n",
" print(\"DONE!\")\n",
" print(\"Download completed!\")\n",
" download_path = os.path.join(os.getcwd(), os.path.basename(MODEL_URL))\n",
"\n",
"else:\n",
" save_name = MODEL_LOCATION\n",
" save_name = MODEL_LOCATION\n",
"\n",
"%cd /content/EveryDream2trainer\n",
"#@markdown * If you chose to link to a .ckpt Select the correct model version in the drop down menu for conversion\n",
"\n",
"inference_yaml = \" \"\n",
"\n",
"# Check if the downloaded or copied model is a .ckpt file\n",
"#@markdown Is the model 1.5 or 2.1 based?\n",
"model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
"\n",
"if download_path.endswith(\".ckpt\") or MODEL_LOCATION.endswith(\".ckpt\"):\n",
" Flag = True\n",
" model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
" If_Ckpt = True\n",
" save_path = download_path\n",
" if \".ckpt\" in save_name:\n",
" save_name = save_name.replace(\".ckpt\", \"\")\n",
@ -171,6 +253,7 @@
" img_size = 512\n",
" upscale_attention = False\n",
" prediction_type = \"epsilon\"\n",
"\n",
" if model_type == \"SD1x\":\n",
" inference_yaml = \"v1-inference.yaml\"\n",
" elif model_type == \"SD2_512_base\":\n",
@ -182,7 +265,7 @@
" inference_yaml = \"v2-inference-v.yaml\"\n",
" img_size = 768\n",
"\n",
" !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
" !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
" --original_config_file $inference_yaml \\\n",
" --image_size $img_size \\\n",
" --checkpoint_path $MODEL_LOCATION \\\n",
@ -190,13 +273,14 @@
" --upcast_attn False \\\n",
" --dump_path $save_name\n",
"\n",
" # Set the save path to the GDrive directory if cache_to_gdrive is True\n",
"# Set the save path to the GDrive directory if cache_to_gdrive is True\n",
"if If_Ckpt:\n",
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
"\n",
"if Flag:\n",
" save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
"if inference_yaml != \" \":\n",
" print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
"print(\"Model \" + save_name + \" will be used!\")\n"
" print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
"\n",
"print(\"Model \" + save_name + \" will be used!\")"
]
},
{
@ -245,20 +329,27 @@
"#@markdown * Name your project so you can find it in your logs\n",
"Project_Name = \"My_Project\" #@param{type: 'string'}\n",
"\n",
"# Load the JSON file\n",
"with open('optimizer.json', 'r') as file:\n",
"\n",
"\n",
"if model_type == 'SD2_512_base' or model_type == 'SD21':\n",
" file_path = \"/content/EveryDream2trainer/optimizerSD21.json\"\n",
"else:\n",
" file_path = \"/content/EveryDream2trainer/optimizer.json\"\n",
"\n",
"with open(file_path, 'r') as file:\n",
" data = json.load(file)\n",
"\n",
"\n",
"#@markdown * The learning rate affects how much \"training\" is done on the model per training step. It is a very careful balance to select a value that will learn your data and not wreck the model. \n",
"#@markdown Leave this default unless you are very comfortable with training and know what you are doing.\n",
"Learning_Rate = 1e-6 #@param{type: 'number'}\n",
"#@markdown * chosing this will allow you to ignore any settings specific to the text encode and will match it with the Unets settings, recommended for beginers.\n",
"Match_text_to_Unet = False #@param{type:\"boolean\"}\n",
"Text_lr = 0.5e-6 #@param {type:\"number\"}\n",
"Text_lr = 5e-7 #@param {type:\"number\"}\n",
"#@markdown * A learning rate scheduler can change your learning rate as training progresses.\n",
"#@markdown * I recommend sticking with constant until you are comfortable with general training. \n",
"Schedule = \"constant\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
"Text_lr_scheduler = \"constant\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
"Schedule = \"linear\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
"Text_lr_scheduler = \"linear\" #@param [\"constant\", \"polynomial\", \"linear\", \"cosine\"] {allow-input: true}\n",
"#@markdown * warm up steps are useful for validation and cosine lrs\n",
"lr_warmup_steps = 0 #@param{type:\"integer\"}\n",
"lr_decay_steps = 0 #@param {type:\"number\"} \n",
@ -280,7 +371,7 @@
"data['text_encoder_overrides']['lr_decay_steps'] = Text_lr_decay_steps\n",
"\n",
"# Save the updated JSON data back to the file\n",
"with open('optimizer.json', 'w') as file:\n",
"with open(file_path, 'w') as file:\n",
" json.dump(data, file, indent=4)\n",
"\n",
"#@markdown * Resolution to train at (recommend 512). Higher resolution will require lower batch size (below).\n",
@ -302,11 +393,12 @@
"#@markdown * Location on your Gdrive where your training images are.\n",
"Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
"\n",
"model = save_name\n",
"if not resume:\n",
" model = save_name\n",
"\n",
"#@markdown * Max Epochs to train for, this defines how many total times all your training data is used. Default of 100 is a good start if you are training ~30-40 images of one subject. If you have 100 images, you can reduce this to 40-50 and so forth.\n",
"\n",
"Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:5}\n",
"Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:1}\n",
"\n",
"#@markdown * How often to save checkpoints.\n",
"Save_every_N_epoch = 20 #@param{type:\"integer\"}\n",
@ -361,7 +453,7 @@
"#@markdown use validation with wandb\n",
"\n",
"validatation = False #@param{type:\"boolean\"}\n",
"\n",
"Hide_Warnings = False #@param {type:\"boolean\"}\n",
"\n",
"extensions = ['.zip', '.7z', '.rar', '.tgz']\n",
"uncompressed_dir = 'Training_Data'\n",
@ -394,6 +486,11 @@
"if validatation:\n",
" validate = \"--validation_config validation_default.json\"\n",
"\n",
"\n",
"if Hide_Warnings:\n",
" import warnings\n",
" warnings.filterwarnings(\"ignore\")\n",
"\n",
"wandb_settings = \"\"\n",
"if wandb_token:\n",
" !rm /root/.netrc\n",
@ -446,7 +543,7 @@
" --zero_frequency_noise_ratio $zero_frequency_noise\n",
"\n",
"# Finish the training process\n",
"clear_output()\n",
"# clear_output()\n",
"time.sleep(2)\n",
"print(\"Training is complete, select a model to start training again\")\n",
"time.sleep(2)\n",
@ -456,8 +553,7 @@
" time.sleep(40)\n",
" runtime.unassign()\n",
"\n",
"os.kill(os.getpid(), 9)\n",
"\n"
"os.kill(os.getpid(), 9)"
]
},
{
@ -519,15 +615,24 @@
},
{
"cell_type": "markdown",
"source": [
"## Optional NoteBook Features, read all the documentation in /content/EveryDream2trainer/doc before proceeding."
],
"metadata": {
"id": "fzXLJVC6OCeP"
}
},
"source": [
"## Optional NoteBook Features, read all the documentation in /content/EveryDream2trainer/doc before proceeding."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"colab": {
"background_save": true
},
"id": "BafdWaYymg0O"
},
"outputs": [],
"source": [
"#@title Remove logs for samples when training (optional) run before training\n",
"file_path = \"/content/EveryDream2trainer/utils/sample_generator.py\"\n",
@ -550,170 +655,8 @@
"with open(file_path, \"w\") as file:\n",
" file.write(content)\n",
"\n",
"print(\"The specified code block has been deleted.\")\n"
],
"metadata": {
"id": "BafdWaYymg0O",
"cellView": "form"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"#@title train.json Editor { display-mode: \"form\" }\n",
"#title json Editor for ED2\n",
"\n",
"import json\n",
"\n",
"data = {\n",
" \"disable_textenc_training\": False,\n",
" \"disable_xformers\": False,\n",
" \"disable_amp\": False,\n",
" \"save_optimizer\": False,\n",
" \"gradient_checkpointing\": True,\n",
" \"wandb\": False,\n",
" \"write_schedule\": False,\n",
" \"rated_dataset\": False,\n",
" \"batch_size\": 10,\n",
" \"ckpt_every_n_minutes\": None,\n",
" \"clip_grad_norm\": None,\n",
" \"clip_skip\": 0,\n",
" \"cond_dropout\": 0.04,\n",
" \"data_root\": \"X:\\\\my_project_data\\\\project_abc\",\n",
" \"flip_p\": 0.0,\n",
" \"gpuid\": 0,\n",
" \"grad_accum\": 1,\n",
" \"logdir\": \"logs\",\n",
" \"log_step\": 25,\n",
" \"lr\": 1.5e-06,\n",
" \"lr_decay_steps\": 0,\n",
" \"lr_scheduler\": \"constant\",\n",
" \"lr_warmup_steps\": None,\n",
" \"max_epochs\": 30,\n",
" \"optimizer_config\": \"optimizer.json\",\n",
" \"project_name\": \"project_abc\",\n",
" \"resolution\": 512,\n",
" \"resume_ckpt\": \"sd_v1-5_vae\",\n",
" \"run_name\": None,\n",
" \"sample_prompts\": \"sample_prompts.txt\",\n",
" \"sample_steps\": 300,\n",
" \"save_ckpt_dir\": None,\n",
" \"save_ckpts_from_n_epochs\": 0,\n",
" \"save_every_n_epochs\": 20,\n",
" \"seed\": 555,\n",
" \"shuffle_tags\": False,\n",
" \"validation_config\": \"validation_default.json\",\n",
" \"rated_dataset_target_dropout_percent\": 50,\n",
" \"zero_frequency_noise_ratio\": 0.02\n",
"}\n",
"\n",
"%cd /content/EveryDream2trainer\n",
"#@markdown JSON Parameters\n",
"findlast = \"\" \n",
"Resume_Last_Training_session = False #@param {type:\"boolean\"}\n",
"findlast == Resume_Last_Training_session\n",
"disable_textenc_training = False #@param {type:\"boolean\"}\n",
"data[\"disable_textenc_training\"] = disable_textenc_training\n",
"disable_xformers = False #@param {type:\"boolean\"}\n",
"data[\"disable_xformers\"] = disable_xformers\n",
"gradient_checkpointing = True #@param {type:\"boolean\"}\n",
"data[\"gradient_checkpointing\"] = gradient_checkpointing\n",
"save_optimizer = False #@param {type:\"boolean\"}\n",
"data[\"save_optimizer\"] = save_optimizer \n",
"scale_lr = False #@param {type:\"boolean\"}\n",
"data[\"scale_lr\"] = scale_lr\n",
"shuffle_tags = False #@param {type:\"boolean\"}\n",
"data[\"shuffle_tags\"] = shuffle_tags\n",
"wandb = False #@param {type:\"boolean\"}\n",
"data[\"wandb\"] = wandb\n",
"write_schedule = False #@param {type:\"boolean\"}\n",
"data[\"write_schedule\"] = write_schedule\n",
"rated_dataset = False #@param {type:\"boolean\"}\n",
"data[\"rated_dataset\"] = rated_dataset \n",
"batch_size = 8 #@param {type:\"integer\"}\n",
"data[\"batch_size\"] = batch_size\n",
"ckpt_every_n_minutes = None #@param {type:\"raw\"}\n",
"data[\"ckpt_every_n_minutes\"] = ckpt_every_n_minutes\n",
"clip_grad_norm = None #@param {type:\"raw\"}\n",
"data[\"clip_grad_norm\"] = clip_grad_norm\n",
"clip_skip = 0 #@param {type:\"integer\"}\n",
"data[\"clip_skip\"] = clip_skip\n",
"cond_dropout = 0.04 #@param {type:\"number\"}\n",
"data[\"cond_dropout\"] = cond_dropout\n",
"data_root = \"X:\\\\my_project_data\\\\project_abc\" #@param {type:\"string\"}\n",
"data[\"data_root\"] = data_root\n",
"flip_p = 0.0 #@param {type:\"number\"}\n",
"data[\"flip_p\"] = flip_p\n",
"grad_accum = 1 #@param {type:\"integer\"}\n",
"data[\"grad_accum\"] = grad_accum\n",
"logdir = \"logs\" #@param {type:\"string\"}\n",
"data[\"logdir\"] = logdir\n",
"log_step = 25 #@param {type:\"integer\"}\n",
"data[\"log_step\"] = log_step\n",
"lr = 1.5e-06 #@param {type:\"number\"}\n",
"data[\"lr\"] = lr\n",
"lr_decay_steps = 0 #@param {type:\"integer\"}\n",
"data[\"lr_decay_steps\"] = lr_decay_steps\n",
"lr_scheduler = \"constant\" #@param {type:\"string\"}\n",
"data[\"lr_scheduler\"] = lr_scheduler\n",
"lr_warmup_steps = None #@param {type:\"raw\"}\n",
"data[\"lr_warmup_steps\"] = lr_warmup_steps\n",
"max_epochs = 100 #@param {type:\"integer\"}\n",
"data[\"max_epochs\"] = max_epochs\n",
"optimizer_config = \"optimizer.json\" #@param {type:\"string\"}\n",
"data[\"optimizer_config\"] = optimizer_config\n",
"project_name = \"project_abc\" #@param {type:\"string\"}\n",
"data[\"project_name\"] = project_name\n",
"resolution = 512 #@param {type:\"integer\"}\n",
"data[\"resolution\"] = resolution\n",
"resume_ckpt = \"sd_v1-5_vae\" #@param {type:\"string\"}\n",
"if findlast:\n",
" resume_ckpt = \"findlast\"\n",
"data[\"resume_ckpt\"] = resume_ckpt\n",
"run_name = None #@param {type:\"raw\"}\n",
"data[\"run_name\"] = run_name\n",
"sample_prompts = \"sample_prompts.txt\" #@param [\"sample_prompts.txt\", \"sample_prompts.json\"]\n",
"data[\"sample_prompts\"] = sample_prompts\n",
"sample_steps = 300 #@param {type:\"integer\"}\n",
"data[\"sample_steps\"] = sample_steps\n",
"save_ckpt_dir = None #@param {type:\"raw\"}\n",
"data[\"save_ckpt_dir\"] = save_ckpt_dir\n",
"save_ckpts_from_n_epochs = 0 #@param {type:\"integer\"}\n",
"data[\"save_ckpts_from_n_epochs\"] = save_ckpts_from_n_epochs\n",
"save_every_n_epochs = 20 #@param {type:\"integer\"}\n",
"data[\"save_every_n_epochs\"] = save_every_n_epochs\n",
"seed = 555 #@param {type:\"integer\"}\n",
"data[\"seed\"] = seed\n",
"validation_config = \"validation_default.json\" #@param {type:\"string\"}\n",
"data[\"validation_config\"] = validation_config\n",
"rated_dataset_target_dropout_percent = 50 #@param {type:\"integer\"}\n",
"data[\"rated_dataset_target_dropout_percent\"] = rated_dataset_target_dropout_percent\n",
"zero_frequency_noise_ratio = 0.02 #@param {type:\"number\"}\n",
"data[\"zero_frequency_noise_ratio\"] = zero_frequency_noise_ratio\n",
"\n",
"\n",
"\n",
"# Display the modified JSON data\n",
"print(\"Modified JSON data:\")\n",
"print(json.dumps(data, indent=2))\n",
"\n",
"\n",
"# Save the modified JSON data to a file\n",
"filename = \"train.json\" #@param {type:\"string\"}\n",
"variable_name = \"\" #@param {type:\"string\"}\n",
"\n",
"with open(filename, 'w') as file:\n",
" json.dump(data, file, indent=2)\n",
"\n",
"print(f\"Modified JSON data saved to '{filename}'.\")"
],
"metadata": {
"id": "d20kz8EtosWM"
},
"execution_count": null,
"outputs": []
"print(\"The specified code block has been deleted.\")"
]
},
{
"cell_type": "code",
@ -725,7 +668,7 @@
"outputs": [],
"source": [
"#@title Alternate startup script\n",
"#@markdown * Edit train.json to setup your paramaters\n",
"#@markdown * Edit train.json or chain0.json to setup your paramaters\n",
"\n",
"#@markdown * Edit using a chain length of 0 will use train.json\n",
"\n",
@ -733,8 +676,6 @@
"\n",
"#@markdown * make sure to check each confguration you will need 1 Json per chain length 3 are provided\n",
"\n",
"#@markdown * make sure your .Json contain the line Notebook: true\n",
"\n",
"#@markdown * your locations in the .json can be done in this format /content/drive/MyDrive/ - then the sub folder you wish to use\n",
"\n",
"%cd /content/EveryDream2trainer\n",
@ -748,18 +689,26 @@
" l -= 1\n",
" I =+ 1"
]
},
{
"cell_type": "markdown",
"source": [
"Need some tools to Manage your large datasets check out https://github.com/victorchall/EveryDream for some usefull tools and captioner"
],
"metadata": {
"id": "ls6mX94trxZV"
}
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": [],
"gpuType": "T4",
"include_colab_link": true
},
"gpuClass": "standard",
"kernelspec": {
"display_name": "venv",
"language": "python",
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
@ -774,4 +723,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@ -248,7 +248,7 @@ def __get_all_aspects():
]
def get_rational_aspect_ratio(bucket_wh: Tuple[int]) -> Tuple[int]:
def get_rational_aspect_ratio(bucket_wh: Tuple[int, int]) -> Tuple[int]:
def farey_aspect_ratio_pair(x: float, max_denominator_value: int):
if x <= 1:
return farey_aspect_ratio_pair_lt1(x, max_denominator_value)

View File

@ -15,15 +15,17 @@ limitations under the License.
"""
import bisect
import logging
import os.path
from collections import defaultdict
import math
import copy
import random
from data.image_train_item import ImageTrainItem
from typing import List, Dict
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
import PIL.Image
from utils.first_fit_decreasing import first_fit_decreasing
PIL.Image.MAX_IMAGE_PIXELS = 715827880*4 # increase decompression bomb error limit to 4x default
class DataLoaderMultiAspect():
@ -34,9 +36,10 @@ class DataLoaderMultiAspect():
seed: random seed
batch_size: number of images per batch
"""
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1):
def __init__(self, image_train_items: list[ImageTrainItem], seed=555, batch_size=1, grad_accum=1):
self.seed = seed
self.batch_size = batch_size
self.grad_accum = grad_accum
self.prepared_train_data = image_train_items
random.Random(self.seed).shuffle(self.prepared_train_data)
self.prepared_train_data = sorted(self.prepared_train_data, key=lambda img: img.caption.rating())
@ -103,14 +106,18 @@ class DataLoaderMultiAspect():
buckets = {}
batch_size = self.batch_size
grad_accum = self.grad_accum
for image_caption_pair in picked_images:
image_caption_pair.runt_size = 0
target_wh = image_caption_pair.target_wh
if (target_wh[0],target_wh[1]) not in buckets:
buckets[(target_wh[0],target_wh[1])] = []
buckets[(target_wh[0],target_wh[1])].append(image_caption_pair)
bucket_key = (image_caption_pair.batch_id,
image_caption_pair.target_wh[0],
image_caption_pair.target_wh[1])
if bucket_key not in buckets:
buckets[bucket_key] = []
buckets[bucket_key].append(image_caption_pair)
# handle runts by randomly duplicating items
for bucket in buckets:
truncate_count = len(buckets[bucket]) % batch_size
if truncate_count > 0:
@ -125,13 +132,19 @@ class DataLoaderMultiAspect():
buckets[bucket] = buckets[bucket][:current_bucket_size - truncate_count]
buckets[bucket].extend(runt_bucket)
# flatten the buckets
items: list[ImageTrainItem] = []
for bucket in buckets:
items.extend(buckets[bucket])
# handle batch_id
# unlabelled data (no batch_id) is in batches labelled DEFAULT_BATCH_ID.
items_by_batch_id = collapse_buckets_by_batch_id(buckets)
items = flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id,
batch_size=batch_size,
grad_accum=grad_accum)
effective_batch_size = batch_size * grad_accum
items = chunked_shuffle(items, chunk_size=effective_batch_size, randomizer=randomizer)
return items
def __pick_random_subset(self, dropout_fraction: float, picker: random.Random) -> list[ImageTrainItem]:
"""
Picks a random subset of all images
@ -174,4 +187,53 @@ class DataLoaderMultiAspect():
self.ratings_summed: list[float] = []
for item in self.prepared_train_data:
self.rating_overall_sum += item.caption.rating()
self.ratings_summed.append(self.rating_overall_sum)
self.ratings_summed.append(self.rating_overall_sum)
def chunk(l: List, chunk_size) -> List:
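# split l into consecutive chunks of chunk_size items; the last chunk may be shorter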
num_chunks = int(math.ceil(float(len(l)) / chunk_size))
return [l[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
def unchunk(chunked_list: List):
return [i for c in chunked_list for i in c]
def collapse_buckets_by_batch_id(buckets: Dict) -> Dict:
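# merge the per-(batch_id, width, height) buckets into a single list of items per batch_id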
batch_ids = [k[0] for k in buckets.keys()]
items_by_batch_id = {}
for batch_id in batch_ids:
items_by_batch_id[batch_id] = unchunk([b for bucket_key,b in buckets.items() if bucket_key[0] == batch_id])
return items_by_batch_id
def flatten_buckets_preserving_named_batch_adjacency(items_by_batch_id: Dict[str, List[ImageTrainItem]],
batch_size: int,
grad_accum: int) -> List[ImageTrainItem]:
# precondition: items_by_batch_id has no incomplete batches
assert(all((len(v) % batch_size)==0 for v in items_by_batch_id.values()))
# ensure we don't mix up aspect ratios by treating each chunk of batch_size images as
# a single unit to pass to first_fit_decreasing()
filler_items = chunk(items_by_batch_id.get(DEFAULT_BATCH_ID, []), batch_size)
custom_batched_items = [chunk(v, batch_size) for k, v in items_by_batch_id.items() if k != DEFAULT_BATCH_ID]
neighbourly_chunked_items = first_fit_decreasing(custom_batched_items,
batch_size=grad_accum,
filler_items=filler_items)
items: List[ImageTrainItem] = unchunk(neighbourly_chunked_items)
return items
def chunked_shuffle(l: List, chunk_size: int, randomizer: random.Random) -> List:
"""
Shuffles l in chunks, preserving the chunk boundaries and the order of items within each chunk.
If the last chunk is incomplete, it is not shuffled (i.e. preserved as the last chunk)
"""
# chunk by effective batch size
chunks = chunk(l, chunk_size)
# preserve last chunk as last if it is incomplete
last_chunk = None
if len(chunks[-1]) < chunk_size:
last_chunk = chunks.pop(-1)
randomizer.shuffle(chunks)
if last_chunk is not None:
chunks.append(last_chunk)
l = unchunk(chunks)
return l

View File

@ -1,10 +1,7 @@
import os
import logging
import yaml
import json
from functools import total_ordering
from attrs import define, field, Factory
from attrs import define, field
from data.image_train_item import ImageCaption, ImageTrainItem
from utils.fs_helpers import *
from typing import Iterable
@ -50,6 +47,7 @@ class ImageConfig:
rating: float = None
max_caption_length: int = None
tags: dict[Tag, None] = field(factory=dict, converter=safe_set)
batch_id: str = None
# Options
multiply: float = None
@ -70,6 +68,7 @@ class ImageConfig:
cond_dropout=overlay(other.cond_dropout, self.cond_dropout),
flip_p=overlay(other.flip_p, self.flip_p),
shuffle_tags=overlay(other.shuffle_tags, self.shuffle_tags),
batch_id=overlay(other.batch_id, self.batch_id)
)
@classmethod
@ -84,6 +83,7 @@ class ImageConfig:
cond_dropout=data.get("cond_dropout"),
flip_p=data.get("flip_p"),
shuffle_tags=data.get("shuffle_tags"),
batch_id=data.get("batch_id")
)
# Alternatively parse from dedicated `caption` attribute
@ -168,6 +168,8 @@ class Dataset:
cfgs.append(ImageConfig.from_file(fileset['local.yaml']))
if 'local.yml' in fileset:
cfgs.append(ImageConfig.from_file(fileset['local.yml']))
if 'batch_id.txt' in fileset:
cfgs.append(ImageConfig(batch_id=read_text(fileset['batch_id.txt'])))
result = ImageConfig.fold(cfgs)
if 'shuffle_tags.txt' in fileset:
@ -262,6 +264,7 @@ class Dataset:
multiplier=config.multiply or 1.0,
cond_dropout=config.cond_dropout,
shuffle_tags=config.shuffle_tags,
batch_id=config.batch_id
)
items.append(item)
except Exception as e:

View File

@ -124,7 +124,7 @@ class ImageTrainItem:
flip_p: probability of flipping image (0.0 to 1.0)
rating: the relative rating of the images. The rating is measured in comparison to the other images.
"""
def __init__(self,
def __init__(self,
image: PIL.Image,
caption: ImageCaption,
aspects: list[float],
@ -133,6 +133,7 @@ class ImageTrainItem:
multiplier: float=1.0,
cond_dropout=None,
shuffle_tags=False,
batch_id: str=None
):
self.caption = caption
self.aspects = aspects
@ -143,6 +144,8 @@ class ImageTrainItem:
self.multiplier = multiplier
self.cond_dropout = cond_dropout
self.shuffle_tags = shuffle_tags
self.batch_id = batch_id or DEFAULT_BATCH_ID
self.target_wh = None
self.image_size = None
if image is None or len(image) == 0:
@ -351,3 +354,6 @@ class ImageTrainItem:
image = image.crop((x_crop, y_crop, x_crop + min_xy, y_crop + min_xy))
return image
DEFAULT_BATCH_ID = "default_batch"

View File

@ -151,6 +151,17 @@ Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/z
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergence over time.
## Keeping images together (custom batching)
If you have a subset of your dataset that expresses the same style or concept, training quality may be improved by putting all of these images through the trainer together in the same batch or batches, instead of the default behaviour (which is to shuffle them randomly throughout the entire dataset).
To control this, put a file called `batch_id.txt` into a folder to give a unique name to the training data in this folder. For example, if you have a bunch of images of horses and you are trying to train them as a single concept, you can assign a unique name such as "my_horses" to these images by putting the word `my_horses` inside `batch_id.txt` in your folder with horse images.
> Note that because this creates extra aspect ratio buckets, you need to be very careful about correlating the number of images to your training batch size. Aim to have an exact multiple of `batch_size` images at each aspect ratio. For example, if your `batch_size` is 6 and you have images with aspect ratios 4:3, 3:4, and 9:16, you should add or delete images until you have an exact multiple of 6 images (i.e. 6, 12, 18, ...) for each aspect ratio. If you do not do this, the bucketer will duplicate images to fill up each aspect ratio bucket. You'll probably also want to use manual validation to prevent the validator from messing this up.
If you are using `.yaml` files for captioning, you can alternatively add a `batch_id: ` entry to either `local.yaml` or the individual images' `.yaml` files. Note that neither `.yaml` nor `batch_id.txt` files act recursively: they do not apply to subfolders.
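As a minimal sketch of the horse example above: either drop a `batch_id.txt` containing the single word `my_horses` into the folder, or, if you caption with `.yaml` files, add the entry to that folder's `local.yaml` (other keys such as `flip_p` or `shuffle_tags` remain optional):

```yaml
# local.yaml for the folder of horse images
batch_id: my_horses
```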
# Stuff you probably don't need to mess with, but well here it is:

View File

@ -48,6 +48,18 @@ Double check your python version again after setup by running these two commands
Again, this should show 3.10.x
## Docker container
## Local docker container
`docker run -it -p 8888:8888 -p 6006:6006 --gpus all -e JUPYTER_PASSWORD=test1234 -t ghcr.io/victorchall/everydream2trainer:nightly`
```sh
docker compose up
```
And you can either get a shell via:
```sh
docker exec -it everydream2trainer-docker-everydream2trainer-1 /bin/bash
```
Or go to your browser and hit `http://localhost:8888`. The web password is
`test1234` but you can change that in `docker-compose.yml`.
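The password comes from the `JUPYTER_PASSWORD` environment variable set in the compose file added by this commit; the relevant excerpt (see the full `docker-compose.yml` below) looks like:

```yaml
# docker-compose.yml excerpt: change test1234 to set your own Jupyter password
environment:
  - JUPYTER_PASSWORD=test1234
```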
Your current source directory will be mounted into the Jupyter notebook.

19
docker-compose.yml Normal file
View File

@ -0,0 +1,19 @@
version: '3.8'
services:
everydream2trainer:
image: ghcr.io/victorchall/everydream2trainer:nightly
ports:
- "127.0.0.1:8888:8888"
- "127.0.0.1:6006:6006"
environment:
- JUPYTER_PASSWORD=test1234
volumes:
- .:/workspace/EveryDream2trainer
- ./.cache:/root/.cache
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]

View File

@ -87,10 +87,8 @@ ARG CACHEBUST=1
RUN git clone https://github.com/victorchall/EveryDream2trainer
WORKDIR /workspace/EveryDream2trainer
RUN python utils/get_yamls.py && \
mkdir -p logs && mkdir -p input
ADD welcome.txt /
ADD start.sh /
RUN chmod +x /start.sh
CMD [ "/start.sh" ]
CMD [ "/start.sh" ]

View File

@ -2,6 +2,12 @@
cat /welcome.txt
export PYTHONUNBUFFERED=1
if [[ ! -f "v2-inference-v.yaml" ]]; then
python utils/get_yamls.py
fi
mkdir -p logs input
# RunPod SSH
if [[ -v "PUBLIC_KEY" ]] && [[ ! -d "${HOME}/.ssh" ]]
then

View File

@ -303,6 +303,15 @@ class EveryDreamOptimizer():
)
elif optimizer_name == "adamw":
opt_class = torch.optim.AdamW
if "dowg" in optimizer_name:
# coordinate_dowg, scalar_dowg require no additional parameters. Epsilon is overrideable but is unnecessary in all stable diffusion training situations.
import dowg
if optimizer_name == "coordinate_dowg":
opt_class = dowg.CoordinateDoWG
elif optimizer_name == "scalar_dowg":
opt_class = dowg.ScalarDoWG
else:
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are coordinate_dowg and scalar_dowg")
elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
import dadaptation

View File

@ -13,6 +13,7 @@ xformers==0.0.20
pytorch-lightning==1.6.5
OmegaConf==2.2.3
numpy==1.23.5
dowg
lion-pytorch
compel~=1.1.3
OmegaConf==2.2.3

View File

@ -0,0 +1,81 @@
import unittest
from utils.first_fit_decreasing import first_fit_decreasing
class TestFirstFitDecreasing(unittest.TestCase):
def test_single_basic(self):
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3, 4, 5, 6]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 2, 3, 4, 5, 6])
input = [[1, 2, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 2, 3])
def test_multi_basic(self):
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 1, 1, 2, 2, 1])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 1, 1, 1, 2, 2])
input = [[1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [2, 2, 1, 1])
def test_multi_complex(self):
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=2)
self.assertEqual(output, [1, 1, 3, 3, 1, 1, 2, 2, 3])
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=3)
self.assertEqual(output, [1, 1, 1, 3, 3, 3, 2, 2, 1])
input = [[1, 1, 1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 1, 1, 1, 3, 3, 3, 2, 2])
input = [[1, 1], [2, 2], [3, 3, 3]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [3, 3, 3, 2, 1, 1, 2])
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
output = first_fit_decreasing(input, batch_size=4)
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 3, 2, 2, 2, 3, 5, 5, 3])
def test_filler_bucket(self):
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=2, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=3, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 2, 2, 9, 1, 9])
input = [[1, 1, 1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 1, 2, 2, 9, 9])
input = [[1, 1], [2, 2]]
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
self.assertEqual(output, [2, 2, 9, 9, 1, 1])
input = [[1, 1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5]]
output = first_fit_decreasing(input, batch_size=4, filler_items=[9, 9])
self.assertEqual(output, [1, 1, 1, 1, 4, 4, 4, 9, 3, 3, 3, 9, 2, 2, 2, 5, 5])

View File

@ -55,7 +55,7 @@ from data.data_loader import DataLoaderMultiAspect
from data.every_dream import EveryDreamBatch, build_torch_dataloader
from data.every_dream_validation import EveryDreamValidator
from data.image_train_item import ImageTrainItem
from data.image_train_item import ImageTrainItem, DEFAULT_BATCH_ID
from utils.huggingface_downloader import try_download_model_from_hf
from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
@ -297,19 +297,23 @@ def report_image_train_item_problems(log_folder: str, items: list[ImageTrainItem
# at a dupe ratio 1.0, all images in this bucket have effective multiplier 2.0
warn_bucket_dupe_ratio = 0.5
ar_buckets = set([tuple(i.target_wh) for i in items])
def make_bucket_key(item):
return (item.batch_id, int(item.target_wh[0]), int(item.target_wh[1]))
ar_buckets = set(make_bucket_key(i) for i in items)
for ar_bucket in ar_buckets:
count = len([i for i in items if tuple(i.target_wh) == ar_bucket])
count = len([i for i in items if make_bucket_key(i) == ar_bucket])
runt_size = batch_size - (count % batch_size)
bucket_dupe_ratio = runt_size / count
if bucket_dupe_ratio > warn_bucket_dupe_ratio:
aspect_ratio_rational = aspects.get_rational_aspect_ratio(ar_bucket)
aspect_ratio_rational = aspects.get_rational_aspect_ratio((ar_bucket[1], ar_bucket[2]))
aspect_ratio_description = f"{aspect_ratio_rational[0]}:{aspect_ratio_rational[1]}"
batch_id_description = "" if ar_bucket[0] == DEFAULT_BATCH_ID else f" for batch id '{ar_bucket[0]}'"
effective_multiplier = round(1 + bucket_dupe_ratio, 1)
logging.warning(f" * {Fore.LIGHTRED_EX}Aspect ratio bucket {ar_bucket} has only {count} "
f"images{Style.RESET_ALL}. At batch size {batch_size} this makes for an effective multiplier "
f"of {effective_multiplier}, which may cause problems. Consider adding {runt_size} or "
f"more images for aspect ratio {aspect_ratio_description}, or reducing your batch_size.")
f"more images with aspect ratio {aspect_ratio_description}{batch_id_description}, or reducing your batch_size.")
def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
logging.info(f"* DLMA resolution {args.resolution}, buckets: {args.aspects}")
@ -548,6 +552,7 @@ def main(args):
image_train_items=image_train_items,
seed=seed,
batch_size=args.batch_size,
grad_accum=args.grad_accum
)
train_batch = EveryDreamBatch(
@ -785,15 +790,15 @@ def main(args):
lr_textenc = ed_optimizer.get_textenc_lr()
loss_log_step = []
log_writer.add_scalar(tag="hyperparamater/lr unet", scalar_value=lr_unet, global_step=global_step)
log_writer.add_scalar(tag="hyperparamater/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
sum_img = sum(images_per_sec_log_step)
avg = sum_img / len(images_per_sec_log_step)
images_per_sec_log_step = []
if args.amp:
log_writer.add_scalar(tag="hyperparamater/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="hyperparameter/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}

View File

@ -0,0 +1,63 @@
import copy
from typing import List
def first_fit_decreasing(input_items: List[List], batch_size: int, filler_items: List=[]) -> List:
"""
Given as input a list of lists, batch the items so that as much as possible the members of each of the original
lists end up in the same batch. Pad out too-short batches by taking items from the filler_items list, if available.
@return flattened list of all items in input_items and filler_items, arranged such that, as much as possible, items
that are in the same input list end up in the same batch.
"""
def sort_by_length(items: List[List]) -> List[List]:
return sorted(items, key=lambda x: len(x))
remaining = input_items
output = []
while remaining:
remaining = sort_by_length(remaining)
longest = remaining.pop()
if len(longest) == 0:
continue
if len(longest) >= batch_size:
output.append(longest[0:batch_size])
del longest[0:batch_size]
if len(longest)>0:
remaining.append(longest)
else:
# need to build this chunk by combining multiple
combined = longest
while True:
fill_length = batch_size - len(combined)
if fill_length == 0:
break
if len(remaining) == 0 and len(filler_items) == 0:
break
from_filler_bucket = filler_items[0:fill_length]
if len(from_filler_bucket) > 0:
del filler_items[0:fill_length]
combined.extend(from_filler_bucket)
continue
filler = next((r for r in remaining if len(r) <= fill_length), None)
if filler is not None:
remaining.remove(filler)
combined.extend(filler)
else:
# steal from the next longest
next_longest = remaining.pop()
combined.extend(next_longest[0:fill_length])
del next_longest[0:fill_length]
if len(next_longest) > 0:
remaining.append(next_longest)
output.append(combined)
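# any filler items left over go on the end as a final (possibly incomplete) chunk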
output.append(filler_items)
return [i for o in output for i in o]