update kosmos2 caption to use cuda by default
This commit is contained in:
parent
4fb64fed66
commit
24a692801a
|
@ -36,13 +36,27 @@ def get_gpu_memory_map():
|
|||
return info.used/1024/1024
|
||||
|
||||
def remove_starting_string(a, b):
|
||||
print(a)
|
||||
print(b)
|
||||
if b.startswith(a):
|
||||
return b[len(a):] # Remove string A from the beginning of string B
|
||||
elif b.strip().startswith(a.strip()):
|
||||
return b.strip()[len(a.strip()):]
|
||||
else:
|
||||
return b
|
||||
|
||||
def main(args):
|
||||
for root, dirs, files in os.walk(args.data_root):
|
||||
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
|
||||
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
|
||||
|
||||
if not args.cpu:
|
||||
#move to cuda and use float16
|
||||
model = model.half().cuda()
|
||||
print(f"Using cuda, model dtype: {model.dtype}")
|
||||
else:
|
||||
print(f"Using cpu, model dtype: {model.dtype}")
|
||||
|
||||
for root, dirs, files in os.walk(args.data_root):
|
||||
for file in files:
|
||||
#get file extension
|
||||
ext = os.path.splitext(file)[1]
|
||||
|
@ -51,8 +65,7 @@ def main(args):
|
|||
|
||||
full_file_path = os.path.join(root, file)
|
||||
image = Image.open(full_file_path)
|
||||
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
|
||||
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
|
||||
|
||||
|
||||
full_file_path = os.path.join(root, file)
|
||||
image = Image.open(full_file_path)
|
||||
|
@ -60,11 +73,11 @@ def main(args):
|
|||
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
|
||||
|
||||
generated_ids = model.generate(
|
||||
pixel_values=inputs["pixel_values"],
|
||||
input_ids=inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"],
|
||||
pixel_values=inputs["pixel_values"].cuda() if not args.cpu else inputs["pixel_values"],
|
||||
input_ids=inputs["input_ids"].cuda() if not args.cpu else inputs["input_ids"],
|
||||
attention_mask=inputs["attention_mask"].cuda() if not args.cpu else inputs["attention_mask"],
|
||||
image_embeds=None,
|
||||
image_embeds_position_mask=inputs["image_embeds_position_mask"],
|
||||
image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda() if not args.cpu else inputs["image_embeds_position_mask"],
|
||||
use_cache=True,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
)
|
||||
|
@ -75,25 +88,40 @@ def main(args):
|
|||
if not args.keep_prompt:
|
||||
processed_text = remove_starting_string(args.prompt, processed_text)
|
||||
|
||||
print(f"File: {image}, Generated caption: {processed_text}")
|
||||
print(f"File: {full_file_path}, Generated caption: {processed_text}")
|
||||
|
||||
name = os.path.splitext(full_file_path)[0]
|
||||
if not os.path.exists(f"{name}.txt") or args.over_write:
|
||||
if not os.path.exists(f"{name}.txt") or args.overwrite:
|
||||
with open(f"{name}.txt", "w") as f:
|
||||
f.write(processed_text)
|
||||
|
||||
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.over_write):
|
||||
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.overwrite):
|
||||
with open(f"{name}.ent", "w") as entities_file:
|
||||
entities_file.write(entities)
|
||||
gpu_mb_used = get_gpu_memory_map()
|
||||
print(f"gpu usage: {gpu_mb_used:.1f} mb")
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Kosmos-2 captioning script")
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.description = "Kosmos-2 captioning script"
|
||||
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
|
||||
parser.add_argument("--prompt", type=str, default="An image of", help="Prompt for generating caption")
|
||||
parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
|
||||
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
|
||||
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
|
||||
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
|
||||
parser.add_argument("--over_write", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
|
||||
parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
|
||||
parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
|
||||
args = parser.parse_args()
|
||||
parser.print_help()
|
||||
|
||||
if not args.prompt.startswith(" "):
|
||||
args.prompt = " " + args.prompt
|
||||
|
||||
print(f"Captioning images in {args.data_root} with prompt: {args.prompt}")
|
||||
print(f"Ideas for prompts:")
|
||||
print(f" Describe this image in detail: (default)")
|
||||
print(f" An image of ")
|
||||
print(f" A two sentence description of this image:")
|
||||
print()
|
||||
main(args)
|
Loading…
Reference in New Issue