update kosmos2 caption to use cuda by default
This commit is contained in:
parent
4fb64fed66
commit
24a692801a
|
@ -36,12 +36,26 @@ def get_gpu_memory_map():
|
||||||
return info.used/1024/1024
|
return info.used/1024/1024
|
||||||
|
|
||||||
def remove_starting_string(a, b):
|
def remove_starting_string(a, b):
|
||||||
|
print(a)
|
||||||
|
print(b)
|
||||||
if b.startswith(a):
|
if b.startswith(a):
|
||||||
return b[len(a):] # Remove string A from the beginning of string B
|
return b[len(a):] # Remove string A from the beginning of string B
|
||||||
|
elif b.strip().startswith(a.strip()):
|
||||||
|
return b.strip()[len(a.strip()):]
|
||||||
else:
|
else:
|
||||||
return b
|
return b
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
|
||||||
|
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
|
||||||
|
|
||||||
|
if not args.cpu:
|
||||||
|
#move to cuda and use float16
|
||||||
|
model = model.half().cuda()
|
||||||
|
print(f"Using cuda, model dtype: {model.dtype}")
|
||||||
|
else:
|
||||||
|
print(f"Using cpu, model dtype: {model.dtype}")
|
||||||
|
|
||||||
for root, dirs, files in os.walk(args.data_root):
|
for root, dirs, files in os.walk(args.data_root):
|
||||||
for file in files:
|
for file in files:
|
||||||
#get file extension
|
#get file extension
|
||||||
|
@ -51,8 +65,7 @@ def main(args):
|
||||||
|
|
||||||
full_file_path = os.path.join(root, file)
|
full_file_path = os.path.join(root, file)
|
||||||
image = Image.open(full_file_path)
|
image = Image.open(full_file_path)
|
||||||
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
|
|
||||||
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
|
|
||||||
|
|
||||||
full_file_path = os.path.join(root, file)
|
full_file_path = os.path.join(root, file)
|
||||||
image = Image.open(full_file_path)
|
image = Image.open(full_file_path)
|
||||||
|
@ -60,11 +73,11 @@ def main(args):
|
||||||
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
|
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
|
||||||
|
|
||||||
generated_ids = model.generate(
|
generated_ids = model.generate(
|
||||||
pixel_values=inputs["pixel_values"],
|
pixel_values=inputs["pixel_values"].cuda() if not args.cpu else inputs["pixel_values"],
|
||||||
input_ids=inputs["input_ids"],
|
input_ids=inputs["input_ids"].cuda() if not args.cpu else inputs["input_ids"],
|
||||||
attention_mask=inputs["attention_mask"],
|
attention_mask=inputs["attention_mask"].cuda() if not args.cpu else inputs["attention_mask"],
|
||||||
image_embeds=None,
|
image_embeds=None,
|
||||||
image_embeds_position_mask=inputs["image_embeds_position_mask"],
|
image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda() if not args.cpu else inputs["image_embeds_position_mask"],
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
max_new_tokens=args.max_new_tokens,
|
max_new_tokens=args.max_new_tokens,
|
||||||
)
|
)
|
||||||
|
@ -75,25 +88,40 @@ def main(args):
|
||||||
if not args.keep_prompt:
|
if not args.keep_prompt:
|
||||||
processed_text = remove_starting_string(args.prompt, processed_text)
|
processed_text = remove_starting_string(args.prompt, processed_text)
|
||||||
|
|
||||||
print(f"File: {image}, Generated caption: {processed_text}")
|
print(f"File: {full_file_path}, Generated caption: {processed_text}")
|
||||||
|
|
||||||
name = os.path.splitext(full_file_path)[0]
|
name = os.path.splitext(full_file_path)[0]
|
||||||
if not os.path.exists(f"{name}.txt") or args.over_write:
|
if not os.path.exists(f"{name}.txt") or args.overwrite:
|
||||||
with open(f"{name}.txt", "w") as f:
|
with open(f"{name}.txt", "w") as f:
|
||||||
f.write(processed_text)
|
f.write(processed_text)
|
||||||
|
|
||||||
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.over_write):
|
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.overwrite):
|
||||||
with open(f"{name}.ent", "w") as entities_file:
|
with open(f"{name}.ent", "w") as entities_file:
|
||||||
entities_file.write(entities)
|
entities_file.write(entities)
|
||||||
|
gpu_mb_used = get_gpu_memory_map()
|
||||||
|
print(f"gpu usage: {gpu_mb_used:.1f} mb")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
print("Kosmos-2 captioning script")
|
print("Kosmos-2 captioning script")
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.description = "Kosmos-2 captioning script"
|
||||||
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
|
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
|
||||||
parser.add_argument("--prompt", type=str, default="An image of", help="Prompt for generating caption")
|
parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
|
||||||
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
|
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
|
||||||
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
|
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
|
||||||
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
|
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
|
||||||
parser.add_argument("--over_write", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
|
parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
|
||||||
|
parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
parser.print_help()
|
||||||
|
|
||||||
|
if not args.prompt.startswith(" "):
|
||||||
|
args.prompt = " " + args.prompt
|
||||||
|
|
||||||
|
print(f"Captioning images in {args.data_root} with prompt: {args.prompt}")
|
||||||
|
print(f"Ideas for prompts:")
|
||||||
|
print(f" Describe this image in detail: (default)")
|
||||||
|
print(f" An image of ")
|
||||||
|
print(f" A two sentence description of this image:")
|
||||||
|
print()
|
||||||
main(args)
|
main(args)
|
Loading…
Reference in New Issue