Enable float16 for older cards (T4, etc.)
This commit is contained in:
parent
5d0f53646b
commit
a75530471e
|
@ -6,6 +6,9 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"source": [
|
"source": [
|
||||||
"# Open-flamingo Captioning\n",
|
"# Open-flamingo Captioning\n",
|
||||||
|
"This notebook is an implementation of [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) for image captioning. \n",
|
||||||
|
"\n",
|
||||||
|
"This will require HIGH RAM shape on Google Colab, but T4 16gb is enough to run the 3B model. 9B model requires 24GB GPU or better.\n",
|
||||||
"\n",
|
"\n",
|
||||||
"1. Read [Docs](doc/CAPTION.md) for basic usage guide. \n",
|
"1. Read [Docs](doc/CAPTION.md) for basic usage guide. \n",
|
||||||
"2. Open in [Google Colab](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/CaptionFL.ipynb) **OR** Runpod/Vast using the EveryDream2trainer docker container/template and open this notebook.\n",
|
"2. Open in [Google Colab](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/CaptionFL.ipynb) **OR** Runpod/Vast using the EveryDream2trainer docker container/template and open this notebook.\n",
|
||||||
|
@ -22,7 +25,9 @@
|
||||||
"# install dependencies\n",
|
"# install dependencies\n",
|
||||||
"!pip install open-flamingo==2.0.0\n",
|
"!pip install open-flamingo==2.0.0\n",
|
||||||
"!pip install huggingface-hub==0.15.1\n",
|
"!pip install huggingface-hub==0.15.1\n",
|
||||||
"!pip install transformers==4.30.2"
|
"!pip install transformers==4.30.2\n",
|
||||||
|
"!pip install pynvml\n",
|
||||||
|
"!pip install colorama"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -33,7 +38,8 @@
|
||||||
"source": [
|
"source": [
|
||||||
"# Colab only setup (do NOT run for docker/runpod/vast)\n",
|
"# Colab only setup (do NOT run for docker/runpod/vast)\n",
|
||||||
"!git clone https://github.com/victorchall/EveryDream2trainer\n",
|
"!git clone https://github.com/victorchall/EveryDream2trainer\n",
|
||||||
"%cd EveryDream2trainer"
|
"%cd EveryDream2trainer\n",
|
||||||
|
"%mkdir -p /content/EveryDream2trainer/input"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -42,6 +48,7 @@
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
|
"%cd /content/EveryDream2trainer\n",
|
||||||
"#@markdown Optional: Extract all TAR and ZIP files in the input folder (so you can just upload a large TAR/ZIP)\n",
|
"#@markdown Optional: Extract all TAR and ZIP files in the input folder (so you can just upload a large TAR/ZIP)\n",
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"import zipfile\n",
|
"import zipfile\n",
|
||||||
|
@ -84,6 +91,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# 24GB GPU, 9b model\n",
|
"# 24GB GPU, 9b model\n",
|
||||||
|
"%cd /content/EveryDream2trainer\n",
|
||||||
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model \"openflamingo/OpenFlamingo-9B-vitl-mpt7b\""
|
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model \"openflamingo/OpenFlamingo-9B-vitl-mpt7b\""
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
@ -94,6 +102,7 @@
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# 16GB GPU, 3b model\n",
|
"# 16GB GPU, 3b model\n",
|
||||||
|
"%cd /content/EveryDream2trainer\n",
|
||||||
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 8 --model \"openflamingo/OpenFlamingo-3B-vitl-mpt1b\""
|
"%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 8 --model \"openflamingo/OpenFlamingo-3B-vitl-mpt1b\""
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -70,9 +70,19 @@ def get_examples(example_root, image_processor):
|
||||||
print(f" ** Example: {x[0]}")
|
print(f" ** Example: {x[0]}")
|
||||||
return examples
|
return examples
|
||||||
|
|
||||||
|
def get_dtype_for_cuda_device(device=None):
    """Pick the best half-precision dtype for a CUDA device.

    Ampere-class GPUs (compute capability >= 8) support bfloat16 natively;
    older cards such as the T4 (7.5) fall back to float16.

    Args:
        device: Optional device to query (index, string, or torch.device).
            Defaults to None, meaning the current CUDA device — this also
            keeps zero-argument call sites (``get_dtype_for_cuda_device()``)
            working, which previously raised TypeError because ``device``
            was a required parameter that was never passed.

    Returns:
        torch.bfloat16 if the device's major compute capability is >= 8,
        otherwise torch.float16.
    """
    # check compute capability of the requested (or current) device
    compute_capability = torch.cuda.get_device_capability(device)
    if compute_capability[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    return dtype
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
|
device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
|
||||||
dtype = torch.bfloat16 if device == "cuda" else torch.float32
|
dtype = get_dtype_for_cuda_device() if device == "cuda" else torch.float32
|
||||||
|
|
||||||
if args.prompt:
|
if args.prompt:
|
||||||
prompt = args.prompt
|
prompt = args.prompt
|
||||||
|
|
Loading…
Reference in New Issue