add phrase grounding to kosmos2 caption script
This commit is contained in:
parent
a55ab816af
commit
3256f9e33c
|
@ -76,7 +76,12 @@ def main(args):
|
|||
full_file_path = os.path.join(root, file)
|
||||
image = Image.open(full_file_path)
|
||||
|
||||
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
|
||||
if args.phrase_mode:
|
||||
text = GROUNDING + "".join(["<phrase>" + x.strip() + "</phrase>" for x in args.prompt.split(",")])
|
||||
else:
|
||||
text = GROUNDING + args.prompt
|
||||
|
||||
inputs = processor(text=text, images=image, return_tensors="pt")
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=args.dtype != "fp32", dtype=dtype):
|
||||
generated_ids = model.generate(
|
||||
|
@ -98,7 +103,7 @@ def main(args):
|
|||
print(f"File: {full_file_path}, Generated caption: {processed_text}")
|
||||
|
||||
name = os.path.splitext(full_file_path)[0]
|
||||
if not os.path.exists(f"{name}.txt") or args.overwrite:
|
||||
if not os.path.exists(f"{name}.txt") or args.overwrite and not args.save_entities_only:
|
||||
with open(f"{name}.txt", "w") as f:
|
||||
f.write(processed_text)
|
||||
|
||||
|
@ -114,15 +119,20 @@ if __name__ == "__main__":
|
|||
parser.description = "Kosmos-2 captioning script"
|
||||
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
|
||||
parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
|
||||
parser.add_argument("--phrase_mode", action="store_true", default=False, help="uses 'phrase mode' grounding, interprets prompt as csv list of phrases to ground.")
|
||||
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
|
||||
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
|
||||
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
|
||||
parser.add_argument("--save_entities_only", action="store_true", default=False, help="Only save coord box with entities to a separate .ent file, do not write caption .txt")
|
||||
parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite .txt and .ent files if they exist")
|
||||
parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
|
||||
parser.add_argument("--dtype", type=str, default="fp16", help="force a different dtype if using GPU (fp16, bf16, fp32) (default: fp16)")
|
||||
args = parser.parse_args()
|
||||
parser.print_help()
|
||||
|
||||
if args.save_entities_only:
|
||||
args.save_entities = True
|
||||
|
||||
if not args.prompt.startswith(" "):
|
||||
args.prompt = " " + args.prompt
|
||||
|
||||
|
|
Loading…
Reference in New Issue