From 106f208fa6130e95f2ef26e5a16ac8b107856ef1 Mon Sep 17 00:00:00 2001
From: Augusto de la Torre
Date: Fri, 3 Feb 2023 00:10:07 +0100
Subject: [PATCH] Build working version of xformers for A5000/3090s

---
 Train_Runpod.ipynb | 40 +++++++++++++++++++++++++---------------
 1 file changed, 25 insertions(+), 15 deletions(-)

diff --git a/Train_Runpod.ipynb b/Train_Runpod.ipynb
index 70ac476..840245b 100644
--- a/Train_Runpod.ipynb
+++ b/Train_Runpod.ipynb
@@ -88,7 +88,7 @@
 "source": [
 "## Install dependencies\n",
 "\n",
- "**This will take a couple minutes. Wait until it says \"DONE\" to move on.** \n",
+ "**This will take up to 15 minutes (if building xformers). Wait until it says \"DONE\" to move on.** \n",
 "You can ignore \"warnings.\""
 ]
 },
@@ -115,8 +115,21 @@
 "!pip install wandb==0.13.6\n",
 "!pip install colorama==0.4.6\n",
 "!pip install -U triton\n",
+ "!pip install -U ninja\n",
 "\n",
- "!pip install --pre -U xformers"
+ "from subprocess import getoutput\n",
+ "s = getoutput('nvidia-smi')\n",
+ "\n",
+ "if \"A100\" in s:\n",
+ " print(\"Detected A100, installing stable xformers\")\n",
+ " !pip install -U xformers\n",
+ "else:\n",
+ " # A5000/3090/4090 support requires us to build xformers ourselves for now\n",
+ " print(\"Building xformers for SM86\")\n",
+ " !apt-get update && apt-get install -y gcc g++\n",
+ " !export TORCH_CUDA_ARCH_LIST=8.6 && pip install git+https://github.com/facebookresearch/xformers.git@48a77cc#egg=xformers\n",
+ "\n",
+ "print(\"DONE\")"
 ]
 },
 {
@@ -217,20 +230,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "!python train.py --project_name \"ft_v1a_512_1e6\" \\\n",
+ "!python train.py --project_name \"ft_v1a_512_15e7\" \\\n",
 "--resume_ckpt \"{ckpt_name}\" \\\n",
 "--data_root \"input\" \\\n",
 "--resolution 512 \\\n",
 "--batch_size 4 \\\n",
- "--max_epochs 30 \\\n",
- "--save_every_n_epochs 5 \\\n",
- "--lr 1e-6 \\\n",
+ "--max_epochs 200 \\\n",
+ "--save_every_n_epochs 25 \\\n",
+ "--lr 1.5e-6 \\\n",
 "--lr_scheduler constant \\\n",
 "--sample_steps 50 \\\n",
 "--useadam8bit \\\n",
 "--save_full_precision \\\n",
 "--shuffle_tags \\\n",
- "--notebook \\\n",
 "--amp \\\n",
 "--write_schedule"
 ]
@@ -250,21 +262,19 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "\n",
- "!python train.py --project_name \"ft_v1b_512_7e7\" \\\n",
+ "!python train.py --project_name \"ft_v1b_512_07e7\" \\\n",
 "--resume_ckpt findlast \\\n",
 "--data_root \"input\" \\\n",
 "--resolution 512 \\\n",
 "--batch_size 4 \\\n",
- "--max_epochs 10 \\\n",
- "--save_every_n_epochs 3 \\\n",
+ "--max_epochs 50 \\\n",
+ "--save_every_n_epochs 25 \\\n",
 "--lr 0.7e-6 \\\n",
 "--lr_scheduler constant \\\n",
 "--sample_steps 50 \\\n",
 "--useadam8bit \\\n",
 "--save_full_precision \\\n",
 "--shuffle_tags \\\n",
- "--notebook \\\n",
 "--amp \\\n",
 "--write_schedule\n"
 ]
@@ -311,7 +321,7 @@
 "# fill in these three fields:\n",
 "hfusername = \"MyHfUser\"\n",
 "reponame = \"MyRepo\"\n",
- "ckpt_name = \"f_v1-ep10-gs02500.ckpt\"\n",
+ "ckpt_name = \"ft_v1b_512_15e7-ep200-gs02500.ckpt\"\n",
 "\n",
 "\n",
 "target_name = ckpt_name.replace('-','').replace('=','')\n",
@@ -338,7 +348,7 @@
 ],
 "metadata": {
 "kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
@@ -352,7 +362,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.6.10"
+ "version": "3.10.9"
 },
 "vscode": {
 "interpreter": {