update kosmos2 caption to use cuda by default

This commit is contained in:
Victor Hall 2023-11-02 22:53:10 -04:00
parent 4fb64fed66
commit 24a692801a
1 changed files with 40 additions and 12 deletions

View File

@ -36,12 +36,26 @@ def get_gpu_memory_map():
return info.used/1024/1024
def remove_starting_string(a, b):
print(a)
print(b)
if b.startswith(a):
return b[len(a):] # Remove string A from the beginning of string B
elif b.strip().startswith(a.strip()):
return b.strip()[len(a.strip()):]
else:
return b
def main(args):
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
if not args.cpu:
#move to cuda and use float16
model = model.half().cuda()
print(f"Using cuda, model dtype: {model.dtype}")
else:
print(f"Using cpu, model dtype: {model.dtype}")
for root, dirs, files in os.walk(args.data_root):
for file in files:
#get file extension
@ -51,8 +65,7 @@ def main(args):
full_file_path = os.path.join(root, file)
image = Image.open(full_file_path)
model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224")
processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224")
full_file_path = os.path.join(root, file)
image = Image.open(full_file_path)
@ -60,11 +73,11 @@ def main(args):
inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt")
generated_ids = model.generate(
pixel_values=inputs["pixel_values"],
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
pixel_values=inputs["pixel_values"].cuda() if not args.cpu else inputs["pixel_values"],
input_ids=inputs["input_ids"].cuda() if not args.cpu else inputs["input_ids"],
attention_mask=inputs["attention_mask"].cuda() if not args.cpu else inputs["attention_mask"],
image_embeds=None,
image_embeds_position_mask=inputs["image_embeds_position_mask"],
image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda() if not args.cpu else inputs["image_embeds_position_mask"],
use_cache=True,
max_new_tokens=args.max_new_tokens,
)
@ -75,25 +88,40 @@ def main(args):
if not args.keep_prompt:
processed_text = remove_starting_string(args.prompt, processed_text)
print(f"File: {image}, Generated caption: {processed_text}")
print(f"File: {full_file_path}, Generated caption: {processed_text}")
name = os.path.splitext(full_file_path)[0]
if not os.path.exists(f"{name}.txt") or args.over_write:
if not os.path.exists(f"{name}.txt") or args.overwrite:
with open(f"{name}.txt", "w") as f:
f.write(processed_text)
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.over_write):
if args.save_entities and (not os.path.exists(f"{name}.ent") or args.overwrite):
with open(f"{name}.ent", "w") as entities_file:
entities_file.write(entities)
gpu_mb_used = get_gpu_memory_map()
print(f"gpu usage: {gpu_mb_used:.1f} mb")
if __name__ == "__main__":
print("Kosmos-2 captioning script")
parser = argparse.ArgumentParser()
parser.description = "Kosmos-2 captioning script"
parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption")
parser.add_argument("--prompt", type=str, default="An image of", help="Prompt for generating caption")
parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption")
parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved")
parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate")
parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file")
parser.add_argument("--over_write", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite txt and ent files if they exist")
parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda")
args = parser.parse_args()
parser.print_help()
if not args.prompt.startswith(" "):
args.prompt = " " + args.prompt
print(f"Captioning images in {args.data_root} with prompt: {args.prompt}")
print(f"Ideas for prompts:")
print(f" Describe this image in detail: (default)")
print(f" An image of ")
print(f" A two sentence description of this image:")
print()
main(args)