enable float16 for older cards, t4, etc

This commit is contained in:
Victor Hall 2023-07-03 15:24:58 -04:00
parent 5d0f53646b
commit a75530471e
2 changed files with 22 additions and 3 deletions

View File

@ -6,6 +6,9 @@
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Open-flamingo Captioning\n", "# Open-flamingo Captioning\n",
"This notebook is an implementation of [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) for image captioning. \n",
"\n",
"This will require HIGH RAM shape on Google Colab, but a T4 with 16GB is enough to run the 3B model. The 9B model requires a 24GB GPU or better.\n",
"\n", "\n",
"1. Read [Docs](doc/CAPTION.md) for basic usage guide. \n", "1. Read [Docs](doc/CAPTION.md) for basic usage guide. \n",
"2. Open in [Google Colab](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/CaptionFL.ipynb) **OR** Runpod/Vast using the EveryDream2trainer docker container/template and open this notebook.\n", "2. Open in [Google Colab](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/CaptionFL.ipynb) **OR** Runpod/Vast using the EveryDream2trainer docker container/template and open this notebook.\n",
@ -22,7 +25,9 @@
"# install dependencies\n", "# install dependencies\n",
"!pip install open-flamingo==2.0.0\n", "!pip install open-flamingo==2.0.0\n",
"!pip install huggingface-hub==0.15.1\n", "!pip install huggingface-hub==0.15.1\n",
"!pip install transformers==4.30.2" "!pip install transformers==4.30.2\n",
"!pip install pynvml\n",
"!pip install colorama"
] ]
}, },
{ {
@ -33,7 +38,8 @@
"source": [ "source": [
"# Colab only setup (do NOT run for docker/runpod/vast)\n", "# Colab only setup (do NOT run for docker/runpod/vast)\n",
"!git clone https://github.com/victorchall/EveryDream2trainer\n", "!git clone https://github.com/victorchall/EveryDream2trainer\n",
"%cd EveryDream2trainer" "%cd EveryDream2trainer\n",
"%mkdir -p /content/EveryDream2trainer/input"
] ]
}, },
{ {
@ -42,6 +48,7 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"%cd /content/EveryDream2trainer\n",
"#@markdown Optional: Extract all TAR and ZIP files in the input folder (so you can just upload a large TAR/ZIP)\n", "#@markdown Optional: Extract all TAR and ZIP files in the input folder (so you can just upload a large TAR/ZIP)\n",
"import os\n", "import os\n",
"import zipfile\n", "import zipfile\n",
@ -84,6 +91,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# 24GB GPU, 9b model\n", "# 24GB GPU, 9b model\n",
"%cd /content/EveryDream2trainer\n",
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model \"openflamingo/OpenFlamingo-9B-vitl-mpt7b\"" "%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model \"openflamingo/OpenFlamingo-9B-vitl-mpt7b\""
] ]
}, },
@ -94,6 +102,7 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# 16GB GPU, 3b model\n", "# 16GB GPU, 3b model\n",
"%cd /content/EveryDream2trainer\n",
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 8 --model \"openflamingo/OpenFlamingo-3B-vitl-mpt1b\"" "%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 8 --model \"openflamingo/OpenFlamingo-3B-vitl-mpt1b\""
] ]
} }

View File

@ -70,9 +70,19 @@ def get_examples(example_root, image_processor):
print(f" ** Example: {x[0]}") print(f" ** Example: {x[0]}")
return examples return examples
def get_dtype_for_cuda_device(device=None):
    """Pick the best reduced-precision dtype for a CUDA device.

    Devices with compute capability >= 8.0 (Ampere and newer) support
    bfloat16 natively; older cards (e.g. T4, sm_75) must fall back to
    float16.

    Args:
        device: Optional CUDA device (index or torch.device) forwarded to
            torch.cuda.get_device_capability. None means the current
            device. The default makes the zero-argument call in main()
            valid; passing a device positionally still works as before.

    Returns:
        torch.bfloat16 on compute capability >= 8.0, torch.float16 otherwise.
    """
    # get_device_capability returns (major, minor); only major matters here.
    major, _minor = torch.cuda.get_device_capability(device)
    return torch.bfloat16 if major >= 8 else torch.float16
def main(args): def main(args):
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu" device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32 dtype = get_dtype_for_cuda_device() if device == "cuda" else torch.float32
if args.prompt: if args.prompt:
prompt = args.prompt prompt = args.prompt