diff --git a/caption_kosmos2.py b/caption_kosmos2.py index c782a92..6677c00 100644 --- a/caption_kosmos2.py +++ b/caption_kosmos2.py @@ -36,13 +36,27 @@ def get_gpu_memory_map(): return info.used/1024/1024 def remove_starting_string(a, b): + print(a) + print(b) if b.startswith(a): return b[len(a):] # Remove string A from the beginning of string B + elif b.strip().startswith(a.strip()): + return b.strip()[len(a.strip()):] else: return b def main(args): - for root, dirs, files in os.walk(args.data_root): + model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224") + processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") + + if not args.cpu: + #move to cuda and use float16 + model = model.half().cuda() + print(f"Using cuda, model dtype: {model.dtype}") + else: + print(f"Using cpu, model dtype: {model.dtype}") + + for root, dirs, files in os.walk(args.data_root): for file in files: #get file extension ext = os.path.splitext(file)[1] @@ -51,8 +65,7 @@ def main(args): full_file_path = os.path.join(root, file) image = Image.open(full_file_path) - model = AutoModelForVision2Seq.from_pretrained("microsoft/kosmos-2-patch14-224") - processor = AutoProcessor.from_pretrained("microsoft/kosmos-2-patch14-224") + full_file_path = os.path.join(root, file) image = Image.open(full_file_path) @@ -60,11 +73,11 @@ def main(args): inputs = processor(text=GROUNDING+args.prompt, images=image, return_tensors="pt") generated_ids = model.generate( - pixel_values=inputs["pixel_values"], - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], + pixel_values=inputs["pixel_values"].cuda() if not args.cpu else inputs["pixel_values"], + input_ids=inputs["input_ids"].cuda() if not args.cpu else inputs["input_ids"], + attention_mask=inputs["attention_mask"].cuda() if not args.cpu else inputs["attention_mask"], image_embeds=None, - image_embeds_position_mask=inputs["image_embeds_position_mask"], + image_embeds_position_mask=inputs["image_embeds_position_mask"].cuda() if not args.cpu else inputs["image_embeds_position_mask"], use_cache=True, max_new_tokens=args.max_new_tokens, ) @@ -75,25 +88,40 @@ def main(args): if not args.keep_prompt: processed_text = remove_starting_string(args.prompt, processed_text) - print(f"File: {image}, Generated caption: {processed_text}") + print(f"File: {full_file_path}, Generated caption: {processed_text}") name = os.path.splitext(full_file_path)[0] - if not os.path.exists(f"{name}.txt") or args.over_write: + if not os.path.exists(f"{name}.txt") or args.overwrite: with open(f"{name}.txt", "w") as f: f.write(processed_text) - if args.save_entities and (not os.path.exists(f"{name}.ent") or args.over_write): + if args.save_entities and (not os.path.exists(f"{name}.ent") or args.overwrite): with open(f"{name}.ent", "w") as entities_file: entities_file.write(entities) + gpu_mb_used = get_gpu_memory_map() + print(f"gpu usage: {gpu_mb_used:.1f} mb") if __name__ == "__main__": print("Kosmos-2 captioning script") parser = argparse.ArgumentParser() + parser.description = "Kosmos-2 captioning script" parser.add_argument("--data_root", type=str, default="input", help="Path to folder of images to caption") - parser.add_argument("--prompt", type=str, default="An image of", help="Prompt for generating caption") + parser.add_argument("--prompt", type=str, default="Describe this image in detail: ", help="Prompt for generating caption") parser.add_argument("--keep_prompt", action="store_true", default=False, help="will keep the prompt at the start of the caption when saved") parser.add_argument("--max_new_tokens", type=int, default=128, help="Maximum number of tokens to generate") parser.add_argument("--save_entities", action="store_true", default=False, help="Save coord box with entities to a separate .ent file") - parser.add_argument("--over_write", action="store_true", default=False, help="will overwrite txt and ent files if they exist") + parser.add_argument("--overwrite", action="store_true", default=False, help="will overwrite txt and ent files if they exist") + parser.add_argument("--cpu", action="store_true", default=False, help="use cpu instead of cuda") args = parser.parse_args() + parser.print_help() + + if not args.prompt.startswith(" "): + args.prompt = " " + args.prompt + + print(f"Captioning images in {args.data_root} with prompt: {args.prompt}") + print(f"Ideas for prompts:") + print(f" Describe this image in detail: (default)") + print(f" An image of ") + print(f" A two sentence description of this image:") + print() main(args) \ No newline at end of file