diff --git a/CaptionFL.ipynb b/CaptionFL.ipynb index 4642840..60f252b 100644 --- a/CaptionFL.ipynb +++ b/CaptionFL.ipynb @@ -6,6 +6,9 @@ "metadata": {}, "source": [ "# Open-flamingo Captioning\n", + "This notebook is an implementation of [OpenFlamingo](https://github.com/mlfoundations/open_flamingo) for image captioning. \n", + "\n", + "This will require HIGH RAM shape on Google Colab, but T4 16gb is enough to run the 3B model. 9B model requires 24GB GPU or better.\n", "\n", "1. Read [Docs](doc/CAPTION.md) for basic usage guide. \n", "2. Open in [Google Colab](https://colab.research.google.com/github/victorchall/EveryDream2trainer/blob/main/CaptionFL.ipynb) **OR** Runpod/Vast using the EveryDream2trainer docker container/template and open this notebook.\n", @@ -22,7 +25,9 @@ "# install dependencies\n", "!pip install open-flamingo==2.0.0\n", "!pip install huggingface-hub==0.15.1\n", - "!pip install transformers==4.30.2" + "!pip install transformers==4.30.2\n", + "!pip install pynvml\n", + "!pip install colorama" ] }, { @@ -33,7 +38,8 @@ "source": [ "# Colab only setup (do NOT run for docker/runpod/vast)\n", "!git clone https://github.com/victorchall/EveryDream2trainer\n", - "%cd EveryDream2trainer" + "%cd EveryDream2trainer\n", + "%mkdir -p /content/EveryDream2trainer/input" ] }, { @@ -42,6 +48,7 @@ "metadata": {}, "outputs": [], "source": [ + "%cd /content/EveryDream2trainer\n", "#@markdown Optional: Extract all TAR and ZIP files in the input folder (so you can just upload a large TAR/ZIP)\n", "import os\n", "import zipfile\n", @@ -84,6 +91,7 @@ "outputs": [], "source": [ "# 24GB GPU, 9b model\n", + "%cd /content/EveryDream2trainer\n", "%run caption_fl.py --data_root \"input\" --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model \"openflamingo/OpenFlamingo-9B-vitl-mpt7b\"" ] }, @@ -94,6 +102,7 @@ "outputs": [], "source": [ "# 16GB GPU, 3b model\n", + "%cd /content/EveryDream2trainer\n", "%run caption_fl.py --data_root \"input\" --min_new_tokens 20 
--max_new_tokens 30 --num_beams 8 --model \"openflamingo/OpenFlamingo-3B-vitl-mpt1b\""
    ]
   }
diff --git a/caption_fl.py b/caption_fl.py
index 1376631..d674821 100644
--- a/caption_fl.py
+++ b/caption_fl.py
@@ -70,9 +70,25 @@ def get_examples(example_root, image_processor):
         print(f" ** Example: {x[0]}")
     return examples
 
+def get_dtype_for_cuda_device(device=None):
+    """
+    Pick the half-precision dtype to use on a CUDA device: bfloat16 when the
+    device's compute capability major version is >= 8, float16 otherwise.
+
+    device: forwarded to torch.cuda.get_device_capability; None (the default)
+    selects the current CUDA device, so zero-argument calls keep working.
+    """
+    compute_capability = torch.cuda.get_device_capability(device)
+    if compute_capability[0] >= 8:
+        dtype = torch.bfloat16
+    else:
+        dtype = torch.float16
+    return dtype
+
+
 def main(args):
     device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
-    dtype = torch.bfloat16 if device == "cuda" else torch.float32
+    dtype = get_dtype_for_cuda_device() if device == "cuda" else torch.float32
 
     if args.prompt:
         prompt = args.prompt