Update Train_Colab.ipynb

xformers==0.0.17.dev435" is just the version i test can do just pip install xformers, but i did not test 0.0.16 yet, this mixed with the other pr's really speeds up colab
This commit is contained in:
nawnie 2023-02-01 00:42:11 -06:00 committed by GitHub
parent 246b57c3c3
commit c936d9db48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 5 deletions

View File

@ -90,7 +90,7 @@
"from IPython.display import clear_output\n", "from IPython.display import clear_output\n",
"from subprocess import getoutput\n", "from subprocess import getoutput\n",
"s = getoutput('nvidia-smi')\n", "s = getoutput('nvidia-smi')\n",
"!pip install -q torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url \"https://download.pytorch.org/whl/cu116\"\n", "!pip install -q torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url \"https://download.pytorch.org/whl/cu117\"\n",
"!pip install -q transformers==4.25.1\n", "!pip install -q transformers==4.25.1\n",
"!pip install -q diffusers[torch]==0.10.2\n", "!pip install -q diffusers[torch]==0.10.2\n",
"!pip install -q pynvml==11.4.1\n", "!pip install -q pynvml==11.4.1\n",
@ -101,10 +101,7 @@
"!pip install -q protobuf==3.20.1\n", "!pip install -q protobuf==3.20.1\n",
"!pip install -q wandb==0.13.6\n", "!pip install -q wandb==0.13.6\n",
"!pip install -q pyre-extensions==0.0.23\n", "!pip install -q pyre-extensions==0.0.23\n",
"if \"A100\" in s:\n", "!pip install -q xformers==0.0.17.dev435",
" !pip install -q https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/A100_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
"else:\n",
" !pip install -q https://huggingface.co/industriaditat/xformers_precompiles/resolve/main/T4_13dev/xformers-0.0.13.dev0-py3-none-any.whl\n",
"!pip install -q pytorch-lightning==1.6.5\n", "!pip install -q pytorch-lightning==1.6.5\n",
"!pip install -q OmegaConf==2.2.3\n", "!pip install -q OmegaConf==2.2.3\n",
"!pip install -q numpy==1.23.5\n", "!pip install -q numpy==1.23.5\n",