Build working version of xformers for A5000/3090s

This commit is contained in:
Augusto de la Torre 2023-02-03 00:10:07 +01:00
parent 246b57c3c3
commit 106f208fa6
1 changed files with 25 additions and 15 deletions

View File

@ -88,7 +88,7 @@
"source": [
"## Install dependencies\n",
"\n",
"**This will take a couple minutes. Wait until it says \"DONE\" to move on.** \n",
"**This will take up to 15 minutes (if building xformers). Wait until it says \"DONE\" to move on.** \n",
"You can ignore \"warnings.\""
]
},
@ -115,8 +115,21 @@
"!pip install wandb==0.13.6\n",
"!pip install colorama==0.4.6\n",
"!pip install -U triton\n",
"!pip install -U ninja\n",
"\n",
"!pip install --pre -U xformers"
"from subprocess import getoutput\n",
"s = getoutput('nvidia-smi')\n",
"\n",
"if \"A100\" in s:\n",
" print(\"Detected A100, installing stable xformers\")\n",
" !pip install -U xformers\n",
"else:\n",
" # A5000/3090/4090 support requires us to build xformers ourselves for now\n",
" print(\"Building xformers for SM86\")\n",
" !apt-get update && apt-get install -y gcc g++\n",
" !export TORCH_CUDA_ARCH_LIST=8.6 && pip install git+https://github.com/facebookresearch/xformers.git@48a77cc#egg=xformers\n",
"\n",
"print(\"DONE\")"
]
},
{
@ -217,20 +230,19 @@
"metadata": {},
"outputs": [],
"source": [
"!python train.py --project_name \"ft_v1a_512_1e6\" \\\n",
"!python train.py --project_name \"ft_v1a_512_15e7\" \\\n",
"--resume_ckpt \"{ckpt_name}\" \\\n",
"--data_root \"input\" \\\n",
"--resolution 512 \\\n",
"--batch_size 4 \\\n",
"--max_epochs 30 \\\n",
"--save_every_n_epochs 5 \\\n",
"--lr 1e-6 \\\n",
"--max_epochs 200 \\\n",
"--save_every_n_epochs 25 \\\n",
"--lr 1.5e-6 \\\n",
"--lr_scheduler constant \\\n",
"--sample_steps 50 \\\n",
"--useadam8bit \\\n",
"--save_full_precision \\\n",
"--shuffle_tags \\\n",
"--notebook \\\n",
"--amp \\\n",
"--write_schedule"
]
@ -250,21 +262,19 @@
"metadata": {},
"outputs": [],
"source": [
"\n",
"!python train.py --project_name \"ft_v1b_512_7e7\" \\\n",
"!python train.py --project_name \"ft_v1b_512_07e7\" \\\n",
"--resume_ckpt findlast \\\n",
"--data_root \"input\" \\\n",
"--resolution 512 \\\n",
"--batch_size 4 \\\n",
"--max_epochs 10 \\\n",
"--save_every_n_epochs 3 \\\n",
"--max_epochs 50 \\\n",
"--save_every_n_epochs 25 \\\n",
"--lr 0.7e-6 \\\n",
"--lr_scheduler constant \\\n",
"--sample_steps 50 \\\n",
"--useadam8bit \\\n",
"--save_full_precision \\\n",
"--shuffle_tags \\\n",
"--notebook \\\n",
"--amp \\\n",
"--write_schedule\n"
]
@ -311,7 +321,7 @@
"# fill in these three fields:\n",
"hfusername = \"MyHfUser\"\n",
"reponame = \"MyRepo\"\n",
"ckpt_name = \"f_v1-ep10-gs02500.ckpt\"\n",
"ckpt_name = \"ft_v1b_512_15e7-ep200-gs02500.ckpt\"\n",
"\n",
"\n",
"target_name = ckpt_name.replace('-','').replace('=','')\n",
@ -338,7 +348,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -352,7 +362,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.10"
"version": "3.10.9"
},
"vscode": {
"interpreter": {