Add phrase grounding to the Kosmos-2 caption script

This commit is contained in:
Victor Hall 2024-03-03 15:18:54 -05:00
parent a55ab816af
commit 3256f9e33c
1 changed file with 12 additions and 2 deletions

View File

@ -76,7 +76,12 @@ def main(args):
full_file_path = os.path.join(root, file) full_file_path = os.path.join(root, file)
image = Image.open(full_file_path) image = Image.open(full_file_path)
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt") if args.phrase_mode:
text = GROUNDING + "".join(["<phrase>" + x.strip() + "</phrase>" for x in args.prompt.split(",")])
else:
text = GROUNDING + args.prompt
inputs = processor(text=text, images=image, return_tensors="pt")
with torch.cuda.amp.autocast(enabled=args.dtype != "fp32", dtype=dtype): with torch.cuda.amp.autocast(enabled=args.dtype != "fp32", dtype=dtype):
generated_ids = model.generate( generated_ids = model.generate(
@ -98,7 +103,7 @@ def main(args):
print(f"File: {full_file_path}, Generated caption: {processed_text}") print(f"File: {full_file_path}, Generated caption: {processed_text}")
name = os.path.splitext(full_file_path)[0] name = os.path.splitext(full_file_path)[0]
if not os.path.exists(f"{name}.txt") or args.overwrite: if not os.path.exists(f"{name}.txt") or args.overwrite and not args.save_entities_only:
with open(f"{name}.txt", "w") as f: with open(f"{name}.txt", "w") as f:
f.write(processed_text) f.write(processed_text)
@ -114,15 +119,20 @@ if __name__ == "__main__":
parser.description = "Kosmos-2 captioning script" parser.description = "Kosmos-2 captioning script"
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption") parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption") parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
parser.add_argument("--phrase_mode", action="store_true", default=False, help="uses 'phrase mode' grounding, interprets prompt as csv list of phrases to ground.")
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved") parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate") parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file") parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
parser.add_argument("--save_entities_only", action="store_true", default=False, help="Only save coord box with entities to a separate .ent file, do not write caption .txt")
parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite .txt and .ent files if they exist") parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite .txt and .ent files if they exist")
parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda") parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
parser.add_argument("--dtype", type=str, default="fp16", help="force a different dtype if using GPU (fp16, bf16, fp32) (default: fp16)") parser.add_argument("--dtype", type=str, default="fp16", help="force a different dtype if using GPU (fp16, bf16, fp32) (default: fp16)")
args = parser.parse_args() args = parser.parse_args()
parser.print_help() parser.print_help()
if args.save_entities_only:
args.save_entities = True
if not args.prompt.startswith(" "): if not args.prompt.startswith(" "):
args.prompt = " " + args.prompt args.prompt = " " + args.prompt