Enable float16 for older cards (T4, etc.) that lack bfloat16 support

This commit is contained in:
Victor Hall 2023-07-03 15:24:58 -04:00
parent 5d0f53646b
commit a75530471e
2 changed files with 22 additions and 3 deletions

View File

@ -6,6 +6,9 @@
"metadata": {},
"source": [
"# Open-flamingo Captioning\n",
"This notebook is an implementation of [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) for image captioning. \n",
"\n",
"This will require HIGH RAM shape on Google Colab, but T4 16gb is enough to run the 3B model. 9B model requires 24GB GPU or better.\n",
"\n",
"1. Read [Docs](doc/CAPTION.md) for basic usage guide. \n",
"2. Open in [Google Colab](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/CaptionFL.ipynb) **OR** Runpod/Vast using the EveryDream2trainer docker container/template and open this notebook.\n",
@ -22,7 +25,9 @@
"# install dependencies\n",
"!pip install open-flamingo==2.0.0\n",
"!pip install huggingface-hub==0.15.1\n",
"!pip install transformers==4.30.2"
"!pip install transformers==4.30.2\n",
"!pip install pynvml\n",
"!pip install colorama"
]
},
{
@ -33,7 +38,8 @@
"source": [
"# Colab only setup (do NOT run for docker/runpod/vast)\n",
"!git clone https://github.com/victorchall/EveryDream2trainer\n",
"%cd EveryDream2trainer"
"%cd EveryDream2trainer\n",
"%mkdir -p /content/EveryDream2trainer/input"
]
},
{
@ -42,6 +48,7 @@
"metadata": {},
"outputs": [],
"source": [
"%cd /content/EveryDream2trainer\n",
"#@markdown Optional: Extract all TAR and ZIP files in the input folder (so you can just upload a large TAR/ZIP)\n",
"import os\n",
"import zipfile\n",
@ -84,6 +91,7 @@
"outputs": [],
"source": [
"# 24GB GPU, 9b model\n",
"%cd /content/EveryDream2trainer\n",
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model \"openflamingo/OpenFlamingo-9B-vitl-mpt7b\""
]
},
@ -94,6 +102,7 @@
"outputs": [],
"source": [
"# 16GB GPU, 3b model\n",
"%cd /content/EveryDream2trainer\n",
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 8 --model \"openflamingo/OpenFlamingo-3B-vitl-mpt1b\""
]
}

View File

@ -70,9 +70,19 @@ def get_examples(example_root, image_processor):
print(f" ** Example: {x[0]}")
return examples
def get_dtype_for_cuda_device(device=None, compute_capability=None):
    """Pick the best reduced-precision dtype for the current CUDA device.

    Ampere and newer GPUs (compute capability >= 8.0) have native bfloat16
    support; older cards such as the T4 (7.5) only support float16, so we
    fall back to that there.

    Args:
        device: Unused; kept (with a default) for call-site compatibility —
            the capability query always targets the current CUDA device.
        compute_capability: Optional ``(major, minor)`` tuple overriding the
            hardware query; useful for testing or forcing a dtype. When
            ``None``, ``torch.cuda.get_device_capability()`` is consulted.

    Returns:
        ``torch.bfloat16`` on compute capability >= 8.0, else ``torch.float16``.
    """
    if compute_capability is None:
        compute_capability = torch.cuda.get_device_capability()
    return torch.bfloat16 if compute_capability[0] >= 8 else torch.float16
def main(args):
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
dtype = get_dtype_for_cuda_device() if device == "cuda" else torch.float32
if args.prompt:
prompt = args.prompt