diff --git a/Train_Colab.ipynb b/Train_Colab.ipynb
index bb5993e..6d22200 100644
--- a/Train_Colab.ipynb
+++ b/Train_Colab.ipynb
@@ -16,11 +16,47 @@
"id": "blaLMSbkPHhG"
},
"source": [
- "# EveryDream2 Colab Edition\n",
+ "
\n",
+ " \n",
+ "
\n",
"\n",
- "Check out documentation here: https://github.com/victorchall/EveryDream2trainer#docs\n",
+ "
\n",
"\n",
- "And join the discord: https://discord.gg/uheqxU6sXN"
+ "---\n",
+ "\n",
+ "\n",
+ " Colab Edition\n",
+ "
\n",
+ "\n",
+ "---\n",
+ "\n",
+ "
\n",
+ "\n",
+ "Check out the **EveryDream2trainer** documentation and runpod/vastai and local setups here: \n",
+ "\n",
+ "[📚 **Documentation**](https://github.com/victorchall/EveryDream2trainer#docs)\n",
+ "\n",
+ "And join our vibrant community on Discord:\n",
+ "\n",
+ "[💬 **Join the Discord**](https://discord.gg/uheqxU6sXN)\n",
+ "\n",
+ "If you find this tool useful, please consider subscribing to the project on Patreon or making a one-time donation on Ko-fi. Your donations keep this project alive as a free open-source tool with ongoing enhancements.\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n",
+ "\n",
+ "
\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ "
\n"
]
},
{
@@ -33,18 +69,26 @@
"outputs": [],
"source": [
"#@markdown # Setup and Install Dependencies\n",
- "from IPython.display import clear_output\n",
+ "from IPython.display import clear_output, display, HTML\n",
"import subprocess\n",
- "from tqdm.auto import tqdm\n",
"import time\n",
"import os \n",
+ "from tqdm.auto import tqdm\n",
+ "import PIL\n",
+ "\n",
+ "# Defining function for colored text\n",
+ "def colored(r, g, b, text):\n",
+ " return f\"\\033[38;2;{r};{g};{b}m{text} \\033[38;2;255;255;255m\"\n",
"\n",
"#@markdown Optional connect Gdrive But strongly recommended\n",
- "#@markdown This will let you put all your training data and checkpoints directly on your drive. Much faster/easier to continue later, less setup time.\n",
+ "#@markdown This will let you put all your training data and checkpoints directly on your drive. \n",
+ "#@markdown Much faster/easier to continue later, less setup time.\n",
"\n",
"#@markdown Creates /content/drive/MyDrive/everydreamlogs/ckpt\n",
"Mount_to_Gdrive = True #@param{type:\"boolean\"} \n",
"\n",
+ "# Clone the git repository\n",
+ "print(colored(0, 255, 0, 'Cloning git repository...'))\n",
"!git clone https://github.com/victorchall/EveryDream2trainer.git\n",
"\n",
"if Mount_to_Gdrive:\n",
@@ -55,11 +99,10 @@
"\n",
"%cd /content/EveryDream2trainer\n",
"\n",
+ "# Download Arial font\n",
+ "print(colored(0, 255, 0, 'Downloading Arial font...'))\n",
"!wget -O arial.ttf https://raw.githubusercontent.com/matomo-org/travis-scripts/master/fonts/Arial.ttf\n",
"\n",
- "!cp /content/arial.ttf /usr/share/fonts/truetype/\n",
- "\n",
- "\n",
"packages = [\n",
" 'transformers==4.27.1',\n",
" 'diffusers[torch]==0.14.0',\n",
@@ -71,7 +114,7 @@
" 'protobuf==3.20.3',\n",
" 'wandb==0.13.6',\n",
" 'pyre-extensions==0.0.23',\n",
- " '--no-deps xformers==0.0.19',\n",
+ " 'xformers==0.0.20',\n",
" 'pytorch-lightning==1.9.2',\n",
" 'OmegaConf==2.2.3',\n",
" 'wandb',\n",
@@ -80,10 +123,11 @@
" 'lion-pytorch'\n",
"]\n",
"\n",
- "for package in tqdm(packages, desc='Installing packages', unit='package'):\n",
+ "print(colored(0, 255, 0, 'Installing packages...'))\n",
+ "for package in tqdm(packages, desc='Installing packages', unit='package', bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt}'):\n",
" if isinstance(package, tuple):\n",
" package_name, extra_index_url = package\n",
- " cmd = f\"pip install -q {package_name} --extra-index-url {extra_index_url}\"\n",
+ " cmd = f\"pip install -I -q {package_name} --extra-index-url {extra_index_url}\"\n",
" else:\n",
" cmd = f\"pip install -q {package}\"\n",
" \n",
@@ -92,27 +136,61 @@
"clear_output()\n",
"\n",
"\n",
- "\n",
+ "# Execute Python script\n",
+ "print(colored(0, 255, 0, 'Executing Python script...'))\n",
"!python utils/get_yamls.py\n",
"clear_output()\n",
- "## ty Google for cutting out install time by 50%\n",
- "print(\"DONE! installing dependencies.\")\n",
- "GPU = !nvidia-smi\n",
- "print(\"GPU details:\")\n",
- "for line in GPU:\n",
- " print(line)\n",
+ "\n",
+ "print(colored(0, 255, 0, \"DONE! installing dependencies.\"))\n",
+ "\n",
+ "# Import pynvml and get GPU details\n",
+ "import pynvml\n",
+ "\n",
+ "pynvml.nvmlInit()\n",
+ "\n",
+ "handle = pynvml.nvmlDeviceGetHandleByIndex(0)\n",
+ "\n",
+ "gpu_name = pynvml.nvmlDeviceGetName(handle)\n",
+ "gpu_memory = pynvml.nvmlDeviceGetMemoryInfo(handle).total / 1024**3\n",
+ "cuda_version_number = pynvml.nvmlSystemGetCudaDriverVersion_v2()\n",
+ "cuda_version_major = cuda_version_number // 1000\n",
+ "cuda_version_minor = (cuda_version_number % 1000) // 10\n",
+ "cuda_version = f\"{cuda_version_major}.{cuda_version_minor}\"\n",
+ "\n",
+ "pynvml.nvmlShutdown()\n",
+ "\n",
"Python_version = !python --version\n",
- "print(\"\\nPython version:\")\n",
- "print(Python_version[0])\n",
"import torch\n",
- "print(\"\\nPyTorch version:\")\n",
- "print(torch.__version__)\n",
"import torchvision\n",
- "print(\"\\nTorchvision version:\")\n",
- "print(torchvision.__version__)\n",
"import xformers\n",
- "print(\"\\nXFormers version:\")\n",
- "print(xformers.__version__)\n",
+ "\n",
+ "display(HTML(f\"\"\"\n",
+ "\n",
+ " \n",
+ " Python version: | \n",
+ " {Python_version[0]} | \n",
+ " GPU Name: | \n",
+ " {gpu_name} | \n",
+ "
\n",
+ " \n",
+ " PyTorch version: | \n",
+ " {torch.__version__} | \n",
+ " GPU Memory (GB): | \n",
+ " {gpu_memory:.2f} | \n",
+ "
\n",
+ " \n",
+ " Torchvision version: | \n",
+ " {torchvision.__version__} | \n",
+ " CUDA version: | \n",
+ " {cuda_version} | \n",
+ "
\n",
+ " \n",
+ " XFormers version: | \n",
+ " {xformers.__version__} | \n",
+ "
\n",
+ "
\n",
+ "\"\"\"))\n",
+ "\n",
"time.sleep(2)\n"
]
},
@@ -120,49 +198,50 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "id": "unaffeqGP_0A",
- "cellView": "form"
+ "cellView": "form",
+ "id": "unaffeqGP_0A"
},
"outputs": [],
"source": [
"#@title Get A Base Model\n",
"#@markdown Choose SD1.5, Waifu Diffusion 1.3, SD2.1, or 2.1(512) from the dropdown, or paste your own URL in the box\n",
- "#@markdown * Alternately you can link to a HF repo using NAME/MODEL, this does not save to your Gdrive, if you want to save an hf model use the direct url\n",
+ "#@markdown * Alternately you can link to an HF repo using NAME/MODEL, this does not save to your Gdrive, if you want to save an HF model, use the direct URL\n",
"\n",
"#@markdown * Link to a set of diffusers on your Gdrive\n",
"\n",
- "#@markdown * Paste a url, atm there is no support for .safetensors\n",
+ "#@markdown * Paste a URL, atm there is no support for .safetensors\n",
"\n",
"from IPython.display import clear_output\n",
"!mkdir input\n",
"%cd /content/EveryDream2trainer\n",
"MODEL_LOCATION = \"sd_v1-5+vae.ckpt\" #@param [\"sd_v1-5+vae.ckpt\", \"hakurei/waifu-diffusion-v1-3\", \"stabilityai/stable-diffusion-2-1-base\", \"stabilityai/stable-diffusion-2-1\"] {allow-input: true}\n",
+ "\n",
"if MODEL_LOCATION == \"sd_v1-5+vae.ckpt\":\n",
" MODEL_LOCATION = \"panopstor/EveryDream\"\n",
- "Flag = False\n",
+ "\n",
+ "If_Ckpt = False\n",
"import os\n",
"\n",
"download_path = \"\"\n",
"\n",
- "if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION: #maybe just add a radio button to download this should work for now\n",
- " print(\"Downloading \")\n",
+ "if \".co\" in MODEL_LOCATION or \"https\" in MODEL_LOCATION or \"www\" in MODEL_LOCATION:\n",
+ " MODEL_URL = MODEL_LOCATION\n",
+ " print(\"Downloading...\")\n",
" !wget $MODEL_LOCATION\n",
" clear_output()\n",
- " print(\"DONE!\")\n",
+ " print(\"Download completed!\")\n",
" download_path = os.path.join(os.getcwd(), os.path.basename(MODEL_URL))\n",
- "\n",
"else:\n",
- " save_name = MODEL_LOCATION\n",
+ " save_name = MODEL_LOCATION\n",
"\n",
"%cd /content/EveryDream2trainer\n",
- "#@markdown * If you chose to link to a .ckpt Select the correct model version in the drop down menu for conversion\n",
"\n",
"inference_yaml = \" \"\n",
"\n",
"# Check if the downloaded or copied model is a .ckpt file\n",
"#@markdown Is the model 1.5 or 2.1 based?\n",
"if download_path.endswith(\".ckpt\") or MODEL_LOCATION.endswith(\".ckpt\"):\n",
- " Flag = True\n",
+ " If_Ckpt = True\n",
" model_type = \"SD1x\" #@param [\"SD1x\", \"SD2_512_base\", \"SD21\"]\n",
" save_path = download_path\n",
" if \".ckpt\" in save_name:\n",
@@ -171,6 +250,7 @@
" img_size = 512\n",
" upscale_attention = False\n",
" prediction_type = \"epsilon\"\n",
+ "\n",
" if model_type == \"SD1x\":\n",
" inference_yaml = \"v1-inference.yaml\"\n",
" elif model_type == \"SD2_512_base\":\n",
@@ -182,7 +262,7 @@
" inference_yaml = \"v2-inference-v.yaml\"\n",
" img_size = 768\n",
"\n",
- " !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
+ " !python utils/convert_original_stable_diffusion_to_diffusers.py --scheduler_type ddim \\\n",
" --original_config_file $inference_yaml \\\n",
" --image_size $img_size \\\n",
" --checkpoint_path $MODEL_LOCATION \\\n",
@@ -190,12 +270,13 @@
" --upcast_attn False \\\n",
" --dump_path $save_name\n",
"\n",
- " # Set the save path to the GDrive directory if cache_to_gdrive is True\n",
+ "# Set the save path to the GDrive directory if cache_to_gdrive is True\n",
+ "if If_Ckpt:\n",
+ " save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
"\n",
- "if Flag:\n",
- " save_name = os.path.join(\"/content/drive/MyDrive/everydreamlogs/ckpt\", save_name)\n",
"if inference_yaml != \" \":\n",
- " print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
+ " print(\"Model saved to: \" + save_name + \". The \" + inference_yaml + \" was used!\")\n",
+ "\n",
"print(\"Model \" + save_name + \" will be used!\")\n"
]
},
@@ -302,11 +383,12 @@
"#@markdown * Location on your Gdrive where your training images are.\n",
"Dataset_Location = \"/content/drive/MyDrive/training_samples\" #@param {type:\"string\"}\n",
"\n",
- "model = save_name\n",
+ "if not resume:\n",
+ " model = save_name\n",
"\n",
"#@markdown * Max Epochs to train for, this defines how many total times all your training data is used. Default of 100 is a good start if you are training ~30-40 images of one subject. If you have 100 images, you can reduce this to 40-50 and so forth.\n",
"\n",
- "Max_Epochs = 100 #@param {type:\"slider\", min:0, max:200, step:5}\n",
+ "Max_Epochs = 200 #@param {type:\"slider\", min:0, max:200, step:1}\n",
"\n",
"#@markdown * How often to save checkpoints.\n",
"Save_every_N_epoch = 20 #@param{type:\"integer\"}\n",
@@ -361,7 +443,7 @@
"#@markdown use validation with wandb\n",
"\n",
"validatation = False #@param{type:\"boolean\"}\n",
- "\n",
+ "Hide_Warnings = False #@param {type:\"boolean\"}\n",
"\n",
"extensions = ['.zip', '.7z', '.rar', '.tgz']\n",
"uncompressed_dir = 'Training_Data'\n",
@@ -394,6 +476,11 @@
"if validatation:\n",
" validate = \"--validation_config validation_default.json\"\n",
"\n",
+ "\n",
+ "if Hide_Warnings:\n",
+ " import warnings\n",
+ " warnings.filterwarnings(\"ignore\")\n",
+ "\n",
"wandb_settings = \"\"\n",
"if wandb_token:\n",
" !rm /root/.netrc\n",
@@ -446,7 +533,7 @@
" --zero_frequency_noise_ratio $zero_frequency_noise\n",
"\n",
"# Finish the training process\n",
- "clear_output()\n",
+ "# clear_output()\n",
"time.sleep(2)\n",
"print(\"Training is complete, select a model to start training again\")\n",
"time.sleep(2)\n",
@@ -464,8 +551,8 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
- "cellView": "form",
- "id": "8HmIWtODuE6p"
+ "id": "8HmIWtODuE6p",
+ "cellView": "form"
},
"outputs": [],
"source": [
@@ -519,15 +606,21 @@
},
{
"cell_type": "markdown",
- "source": [
- "## Optional NoteBook Features, read all the documentation in /content/EveryDream2trainer/doc before proceeding."
- ],
"metadata": {
"id": "fzXLJVC6OCeP"
- }
+ },
+ "source": [
+ "## Optional NoteBook Features, read all the documentation in /content/EveryDream2trainer/doc before proceeding."
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "cellView": "form",
+ "id": "BafdWaYymg0O"
+ },
+ "outputs": [],
"source": [
"#@title Remove logs for samples when training (optional) run before training\n",
"file_path = \"/content/EveryDream2trainer/utils/sample_generator.py\"\n",
@@ -551,16 +644,15 @@
" file.write(content)\n",
"\n",
"print(\"The specified code block has been deleted.\")\n"
- ],
- "metadata": {
- "id": "BafdWaYymg0O",
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "d20kz8EtosWM"
+ },
+ "outputs": [],
"source": [
"#@title train.json Editor { display-mode: \"form\" }\n",
"#title json Editor for ED2\n",
@@ -571,7 +663,10 @@
" \"disable_textenc_training\": False,\n",
" \"disable_xformers\": False,\n",
" \"disable_amp\": False,\n",
+ " \"lowvram\": False,\n",
+ " \"notebook\": False,\n",
" \"save_optimizer\": False,\n",
+ " \"scale_lr\": False,\n",
" \"gradient_checkpointing\": True,\n",
" \"wandb\": False,\n",
" \"write_schedule\": False,\n",
@@ -708,12 +803,7 @@
" json.dump(data, file, indent=2)\n",
"\n",
"print(f\"Modified JSON data saved to '{filename}'.\")"
- ],
- "metadata": {
- "id": "d20kz8EtosWM"
- },
- "execution_count": null,
- "outputs": []
+ ]
},
{
"cell_type": "code",
@@ -754,6 +844,7 @@
"accelerator": "GPU",
"colab": {
"provenance": [],
+ "gpuType": "T4",
"include_colab_link": true
},
"gpuClass": "standard",