Merge branch 'main' of https://github.com/victorchall/EveryDream2trainer into main
commit fdf230634e

caption_fl.py (209 lines removed)
@@ -1,209 +0,0 @@
"""
Copyright [2022-2023] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os

from PIL import Image
import argparse
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms

import torch
from pynvml import *

import time
from colorama import Fore, Style


SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]

def get_gpu_memory_map():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    nvmlShutdown()
    return info.used/1024/1024

def remove_duplicates(string):
    words = string.split(', ')  # Split the string into individual words
    unique_words = []

    for word in words:
        if word not in unique_words:
            unique_words.append(word)
        else:
            break  # Stop appending words once a duplicate is found

    return ', '.join(unique_words)

def get_examples(example_root, image_processor):
    examples = []
    for root, dirs, files in os.walk(example_root):
        for file in files:
            ext = os.path.splitext(file)[-1].lower()
            if ext in SUPPORTED_EXT:
                # get .txt file of same base name
                txt_file = os.path.splitext(file)[0] + ".txt"
                with open(os.path.join(root, txt_file), 'r') as f:
                    caption = f.read()
                image = Image.open(os.path.join(root, file))
                vision_x = [image_processor(image).unsqueeze(0)]
                #vision_x = torch.cat(vision_x, dim=0)
                #vision_x = vision_x.unsqueeze(1).unsqueeze(0)
                examples.append((caption, vision_x))
    for x in examples:
        print(f" ** Example: {x[0]}")
    return examples

def get_dtype_for_cuda_device(device):
    # check compute capability
    compute_capability = torch.cuda.get_device_capability()
    if compute_capability[0] >= 8:
        dtype = torch.bfloat16
    else:
        dtype = torch.float16
    return dtype


def main(args):
    device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
    dtype = get_dtype_for_cuda_device(device) if device == "cuda" else torch.float32

    if args.prompt:
        prompt = args.prompt
    else:
        prompt = "<image>: "
    print(f" using prompt: {prompt}")

    if "mpt7b" in args.model:
        lang_encoder_path="anas-awadalla/mpt-7b"
        tokenizer_path="anas-awadalla/mpt-7b"
    elif "mpt1b" in args.model:
        lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b"
        tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b"

    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path=lang_encoder_path,
        tokenizer_path=tokenizer_path,
        cross_attn_every_n_layers=1,
    )

    tokenizer.padding_side = "left"

    checkpoint_path = hf_hub_download(args.model, "checkpoint.pt")
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    print(f"GPU memory used, before loading model: {get_gpu_memory_map()} MB")
    model.to(0, dtype=dtype)
    print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")

    # examples give few shot learning for captioning the novel image
    examples = get_examples(args.example_root, image_processor)

    prompt = ""
    output_prompt = "Output:"
    per_image_prompt = "<image> " + output_prompt

    for example in iter(examples):
        prompt += f"{per_image_prompt}{example[0]}<|endofchunk|>"
    prompt += per_image_prompt # prepare for novel example
    prompt = prompt.replace("\n", "") # in case captions had newlines
    print(f" \n** Final full prompt with example pairs: {prompt}")

    # os.walk all files in args.data_root recursively
    for root, dirs, files in os.walk(args.data_root):
        for file in files:
            #get file extension
            ext = os.path.splitext(file)[1]
            if ext.lower() in SUPPORTED_EXT:
                start_time = time.time()

                full_file_path = os.path.join(root, file)
                image = Image.open(full_file_path)

                vision_x = [vx[1][0] for vx in examples]
                vision_x.append(image_processor(image).unsqueeze(0))
                vision_x = torch.cat(vision_x, dim=0)
                vision_x = vision_x.unsqueeze(1).unsqueeze(0)
                vision_x = vision_x.to(device, dtype=dtype)

                lang_x = tokenizer(
                    [prompt], # blank for image captioning
                    return_tensors="pt",
                )
                lang_x.to(device)

                input_ids = lang_x["input_ids"].to(device)

                with torch.cuda.amp.autocast(dtype=dtype), torch.no_grad():
                    generated_text = model.generate(
                        vision_x=vision_x,
                        lang_x=input_ids,
                        attention_mask=lang_x["attention_mask"],
                        max_new_tokens=args.max_new_tokens,
                        min_new_tokens=args.min_new_tokens,
                        num_beams=args.num_beams,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=args.repetition_penalty,
                    )
                del vision_x
                del lang_x

                # trim and clean
                generated_text = tokenizer.decode(generated_text[0][len(input_ids[0]):], skip_special_tokens=True)
                generated_text = generated_text.split(output_prompt)[0]
                generated_text = remove_duplicates(generated_text)

                exec_time = time.time() - start_time
                print(f"* Caption: {generated_text}")
                print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB")

                name = os.path.splitext(full_file_path)[0]
                if not os.path.exists(name):
                    with open(f"{name}.txt", "w") as f:
                        f.write(generated_text)
    print("Done!")

if __name__ == "__main__":
    print(f"Available models:")
    print(f" openflamingo/OpenFlamingo-9B-vitl-mpt7b (default)")
    print(f" openflamingo/OpenFlamingo-3B-vitl-mpt1b")
    print(f" openflamingo/OpenFlamingo-4B-vitl-rpj3b")
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="input", help="Path to images")
    parser.add_argument("--example_root", type=str, default="examples", help="Path to 2-3 precaptioned images to guide generation")
    parser.add_argument("--model", type=str, default="openflamingo/OpenFlamingo-9B-vitl-mpt7b", help="Model name or path")
    parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available")
    parser.add_argument("--min_new_tokens", type=int, default=20, help="minimum number of tokens to generate")
    parser.add_argument("--max_new_tokens", type=int, default=50, help="maximum number of tokens to generate")
    parser.add_argument("--num_beams", type=int, default=8, help="number of beams, more is more accurate but slower")
    parser.add_argument("--prompt", type=str, default="Output: ", help="prompt to use for generation, default is 'Output: '")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature for sampling, 1.0 is default")
    parser.add_argument("--top_k", type=int, default=0, help="top_k sampling, 0 is default")
    parser.add_argument("--top_p", type=float, default=1.0, help="top_p sampling, 1.0 is default")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty, 1.0 is default")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="length penalty, 1.0 is default")
    args = parser.parse_args()

    print(f"** OPEN-FLAMINGO ** Captioning files in: {args.data_root}")
    print(f"** Using model: {args.model}")
    main(args)
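For reference, the few-shot prompt that the deleted script assembled from the example image/caption pairs is a flat string of `<image> Output:` segments joined by `<|endofchunk|>` tokens, ending with an empty `<image> Output:` slot for the image being captioned. A minimal sketch of that construction (the captions below are placeholders, not from the repo):

```python
# Sketch of the few-shot prompt assembly used in caption_fl.py (placeholder captions).
example_captions = ["a red barn in a grassy field", "a tabby cat sleeping on a windowsill"]
output_prompt = "Output:"
per_image_prompt = "<image> " + output_prompt

prompt = ""
for caption in example_captions:
    prompt += f"{per_image_prompt}{caption}<|endofchunk|>"
prompt += per_image_prompt  # trailing slot for the novel image

# -> "<image> Output:a red barn in a grassy field<|endofchunk|><image> Output:a tabby cat sleeping on a windowsill<|endofchunk|><image> Output:"
```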
@@ -76,7 +76,12 @@ def main(args):
                 full_file_path = os.path.join(root, file)
                 image = Image.open(full_file_path)
 
-                inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
+                if args.phrase_mode:
+                    text = GROUNDING + "".join(["<phrase>" + x.strip() + "</phrase>" for x in args.prompt.split(",")])
+                else:
+                    text = GROUNDING + args.prompt
+
+                inputs = processor(text=text, images=image, return_tensors="pt")
 
                 with torch.cuda.amp.autocast(enabled=args.dtype != "fp32", dtype=dtype):
                     generated_ids = model.generate(
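The effect of the new `--phrase_mode` branch is that a comma-separated prompt is rewrapped in `<phrase>` tags before being handed to the processor. A minimal sketch of the string it builds, assuming `GROUNDING` is the Kosmos-2 `<grounding>` token (its actual value is defined elsewhere in caption_kosmos2.py):

```python
# Sketch only: GROUNDING assumed to be "<grounding>"; see caption_kosmos2.py for the real constant.
GROUNDING = "<grounding>"
prompt = "dog,cat,tree"

text = GROUNDING + "".join("<phrase>" + x.strip() + "</phrase>" for x in prompt.split(","))
# -> "<grounding><phrase>dog</phrase><phrase>cat</phrase><phrase>tree</phrase>"
```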
@@ -98,7 +103,7 @@ def main(args):
                 print(f"File: {full_file_path}, Generated caption: {processed_text}")
 
                 name = os.path.splitext(full_file_path)[0]
-                if not os.path.exists(f"{name}.txt") or args.overwrite:
+                if (not os.path.exists(f"{name}.txt") or args.overwrite) and not args.save_entities_only:
                     with open(f"{name}.txt", "w") as f:
                         f.write(processed_text)
 
@@ -114,15 +119,20 @@ if __name__ == "__main__":
     parser.description = "Kosmos-2 captioning script"
     parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
     parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
+    parser.add_argument("--phrase_mode", action="store_true", default=False, help="uses 'phrase mode' grounding, interprets prompt as csv list of phrases to ground.")
     parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
     parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
     parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
+    parser.add_argument("--save_entities_only", action="store_true", default=False, help="Only save coord box with entities to a separate .ent file, do not write caption .txt")
     parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite .txt and .ent files if they exist")
     parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
     parser.add_argument("--dtype", type=str, default="fp16", help="force a different dtype if using GPU (fp16, bf16, fp32) (default: fp16)")
     args = parser.parse_args()
     parser.print_help()
 
+    if args.save_entities_only:
+        args.save_entities = True
+
     if not args.prompt.startswith(" "):
         args.prompt = " " + args.prompt
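Taken together, the new flags support a grounding-only pass over a folder of images. The invocations below are illustrative sketches based on the arguments defined above; the input directory and phrase list are placeholders:

`python caption_kosmos2.py --data_root input --save_entities`

`python caption_kosmos2.py --data_root input --phrase_mode --prompt "dog,cat,tree" --save_entities_only`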
@@ -1,34 +1,27 @@
 # Captioning tools
 
-## Open-Flamingo
+## CogVLM
 
-#### Note: Open-Flamingo currently only works on Torch 2.0.1. If you want to use it, you will have to backdate your torch installation, which will break features in the trainer. I recommend making a separate environment for Open Flamingo captioning instead. You can run through normal install, then `pip install open-flamingo` in the separate envirment to back date torch and make that install open-flamingo only.
+[CogVLM](https://github.com/THUDM/CogVLM) is, so far, the best model for generating synthetic captions. The script for Cog is enhanced, so read the [CogVLM README](CAPTION_COG.md) for more information.
 
-`python caption_fl.py --data_root input --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model "openflamingo/OpenFlamingo-9B-vitl-mpt7b"`
+## Kosmos-2
 
-This script uses two example image/caption pairs located in the `/example` folder to prime the system to caption, then captions the images in the input folder. It will save a `.txt` file of the same base filename with the caption in the same folder.
+Microsoft's [Kosmos-2](https://huggingface.co/microsoft/kosmos-2-patch14-224) is significantly lighter weight than Cog, using <5GB of VRAM and generating captions in under a second on an RTX 3090.
 
-This script currently requires an AMPERE or newer GPU due to using bfloat16.
+It has the capability to output grounding bounding boxes.
 
-**Trying out different example image/caption pairs will influence how the system captions the input images.** Adding more examples slows processing.
+Run `python caption_kosmos2.py --help` to get a list of options.
 
-Supported models:
+### _Kosmos-2 grounding_
 
-* `openflamingo/OpenFlamingo-3B-vitl-mpt1b` Small model, requires 8 GB VRAM a num_beams 3, or 12 GB at num_beams 16
-* `openflamingo/OpenFlamingo-9B-vitl-mpt7b` Large model, requires 24 GB VRAM at num_beams 3, or 36.7gb at num_beams 32
+Kosmos-2 can generate bounding boxes for the "grounding" of the caption. This is useful for identifying specific objects in the image in 2D space, which can be useful in later pipelines.
 
-The small model with more beams (ex. 16) performs well with details and should not be immediately discounted.
+It's worth reading the documentation [here](https://huggingface.co/microsoft/kosmos-2-patch14-224) to understand the grounding output.
 
-The larger model is more accurate with proper names (i.e. identifying well-known celebrities, objects, or locations) and seems to exhibit a larger vocabulary.
+`--save_entities` outputs a '.ent' file with bounding box information. The entities identified will be based on what caption is produced.
 
-Primary params:
+`--phrase_mode` This modifies how the model is called, wrapping phrases in \<phrase> tags. This also interprets your prompt as a CSV, wrapping each item in a phrase tag. You might use it with `--prompt "dog,cat,tree"` for instance. *This is not a guarantee your phrases will be found and output into the grounding output file.*
 
-* `--num_beams 3` increasing uses more VRAM and runs slower, may improve detail, but can increase hallicunations
-* `--min_new_tokens 20` and `--max_new_tokens 35` control the length of the caption
+`--save_entities_only` This will not attempt to write the caption into the .txt file at all. **This is recommended with `--phrase_mode`**. Using this option forces `--save_entities`.
 
-Other settings:
+There is a trivial/dumb UI for viewing the grounding in the scripts folder. Launch it with `python scripts/grounding_ui.py` and it will open a window allowing you to select a directory, and it will display the images and bounding boxes.
 
-* `--force_cpu` forces to use CPU even if a CUDA device is present
-* `--temperature 1.0` relates to randomness used for next token chosen
-* `--repetition_penalty 1.0` penalizes repeating tokens/words, can adjust up if you see repeated terms
-* `--length_penalty 1.0` penalizes longer captions
@@ -120,7 +120,7 @@ I would recommend not setting any of these and leave the default values until yo
 `--no_repeat_ngram_size 3` prevents the same n-gram (successive token sequence) from being repeated in the output. Can help prevent the model from repeating itself.
 
-`--bad_words "foo,bar"` Attempts to prevent the model from using these words in the output caption. Comma-delimited.
+`--bad_words "foo,bar"` Attempts to prevent the model from using these words in the output caption. Comma-delimited. Very useful; consider trying `"depicts,poses,posing,showcases,appears,suggests"` to get more concise phrasing in captions. This is not a guarantee, due to [different tokenizations](https://github.com/huggingface/transformers/issues/17504) being possible for a given bad_word.
 
 `--force_word "photograph,Spain"` Attempts to force the model to include the words in the output caption. Comma-delimited.
 
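As a concrete illustration of the expanded `--bad_words` advice, an invocation might look like the line below. This assumes the Cog captioning script referenced by CAPTION_COG.md is invoked as `caption_cog.py` (an assumption; check that README for the actual command), and the word list is simply the one suggested above:

`python caption_cog.py --bad_words "depicts,poses,posing,showcases,appears,suggests" --no_repeat_ngram_size 3`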
@@ -128,7 +128,7 @@ I would recommend not setting any of these and leave the default values until yo
 `--max_new_tokens 120` Truncates output after n tokens. May cut off captions abruptly.
 
-`--no_repeat_ngram_size 3` prevents the same n-gram from being repeated in the output. Default is 0, which means no n-gram is prevented from repeating. Setting this to 2 or 3 can help prevent the model from repeating itself.
+`--no_repeat_ngram_size 3` prevents the same n-gram (a sequence of n tokens) from being repeated in the output. Default is 0, which means no n-gram is prevented from repeating. Setting this to 2 or 3 can help prevent the model from repeating itself.
 
 `--min_new_tokens 5` Force the model to produce at least n tokens.
 