From 7ea9676da8cc1b04776e55bddbfd4d4a426fed59 Mon Sep 17 00:00:00 2001
From: Victor Hall
Date: Tue, 2 Apr 2024 16:35:49 -0400
Subject: [PATCH] enhance cog caption script

---
 caption_cog.py             | 47 +++++++++++++++++++++++---------------
 plugins/caption_plugins.py |  9 ++++++++
 2 files changed, 38 insertions(+), 18 deletions(-)

diff --git a/caption_cog.py b/caption_cog.py
index cae4d4a..8277234 100644
--- a/caption_cog.py
+++ b/caption_cog.py
@@ -39,10 +39,12 @@ from utils.patch_cog import patch_cog
 from data.generators import image_path_generator, SUPPORTED_EXT
 
 try:
-    from moai.load_moai import prepare_moai
+    from moai.load_moai import prepare_moai
 except ImportError:
     print("moai not found, skipping")
 
+Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
+
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
 
@@ -89,10 +91,10 @@ def build_conversation_input_ids(
     attention_mask = [1] * len(input_ids)
 
     return {
-        'input_ids': torch.tensor(input_ids, dtype=torch.long),
-        'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
-        'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
-        'images': images,
+        "input_ids": torch.tensor(input_ids, dtype=torch.long),
+        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+        "images": images,
     }
 
 def get_gpu_memory_map():
@@ -138,7 +140,7 @@ class MoaiManager:
         self.sgg_model = None
         self.ocr_model = None
 
-    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str='fp16'):
+    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
         moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
             = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
         self.moai_model = moai_model
@@ -162,13 +164,13 @@ class MoaiManager:
             od_processor=self.od_processor,
             sgg_model=self.sgg_model,
             ocr_model=self.ocr_model,
-            device='cuda:0')
+            device="cuda:0")
         return moai_inputs
 
     def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
         with torch.inference_mode():
             generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
-        answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split('[U')[0]
+        answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
         return answer
 
 class CogVLMManager:
@@ -178,7 +180,7 @@ class CogVLMManager:
         self.model = None
 
     def load_model(self):
-        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.bfloat16,
@@ -270,7 +272,7 @@ def main(args):
             image = Image.open(image_path)
 
             try:
-                image = image.convert('RGB')
+                image = image.convert("RGB")
                 image = ImageOps.exif_transpose(image)
             except Exception as e:
                 logging.warning(f"Non-fatal error processing {image_path}: {e}")
@@ -283,12 +285,12 @@ def main(args):
             inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with)  # chat mode
 
             inputs = {
-                'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
-                'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
-                'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
-                'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
-                'output_hidden_states': True,
-                'return_dict': True
+                "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+                "token_type_ids": inputs["token_type_ids"].unsqueeze(0).to("cuda"),
+                "attention_mask": inputs["attention_mask"].unsqueeze(0).to("cuda"),
+                "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)],
+                "output_hidden_states": True,
+                "return_dict": True
             }
             # print(f"** inputs type: {type(inputs)}") # dict
 
@@ -407,10 +409,19 @@ if __name__ == "__main__":
     argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
     argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
     argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
-    args = argparser.parse_args()
-
+    args, unknown_args = argparser.parse_known_args()
+
     configure_logging(args)
 
+    unknown_args_dict = {}
+    for i in range(0, len(unknown_args), 2):
+        key = unknown_args[i].lstrip('-')  # strip the leading dashes: "--key" -> "key"
+        value = unknown_args[i + 1]
+        unknown_args_dict[key] = value
+        setattr(args, key, value)  # add each unknown argument to the args namespace
+
+    logging.info(f"** Unknown args have been added to args for plugins: {Fore.LIGHTGREEN_EX}{unknown_args_dict}{Style.RESET_ALL}")
+
     print(DESCRIPTION)
     print(EXAMPLES)
 
diff --git a/plugins/caption_plugins.py b/plugins/caption_plugins.py
index 7592c54..0af848a 100644
--- a/plugins/caption_plugins.py
+++ b/plugins/caption_plugins.py
@@ -155,11 +155,20 @@ class FromFolderMetadataJson(PromptIdentityBase):
                          args=args)
         self.metadata_provider = MetadataProvider()
 
+    def _clean_metadata(self, metadata: dict, args) -> None:
+        if "remove_keys" in args:
+            keys = args.remove_keys.split(",")
+            logging.debug(f"Removing keys: {keys}")
+            for key in keys:
+                metadata.pop(key, None)
+                logging.debug(f"Removed key: {key}")
+
     def _from_metadata_json(self, args:Namespace) -> dict:
         image_path = args.image_path
         image_dir = os.path.dirname(image_path)
         metadata_json_path = os.path.join(image_dir, "metadata.json")
         metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
+        self._clean_metadata(metadata, args)
         metadata = json.dumps(metadata, indent=2)
         prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt)
         return prompt
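
For context on the caption_cog.py hunk above: parse_known_args() is what lets plugin-specific flags such as --remove_keys reach plugins without being declared on the main argparser. A minimal standalone sketch of that pattern (the sample flag values are hypothetical; note the pairing loop assumes every unknown flag is followed by a value, so an odd-length tail would raise IndexError):

import argparse

argparser = argparse.ArgumentParser()
argparser.add_argument("--image_dir", type=str, default="input")

# parse_known_args returns the declared args plus the leftover raw tokens
args, unknown_args = argparser.parse_known_args(
    ["--image_dir", "photos", "--remove_keys", "tags,score"]
)

# pair the leftover tokens into key/value and graft them onto the namespace,
# mirroring the loop added in caption_cog.py
for i in range(0, len(unknown_args), 2):
    key = unknown_args[i].lstrip("-")  # "--remove_keys" -> "remove_keys"
    setattr(args, key, unknown_args[i + 1])

print(args.remove_keys)  # -> tags,score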
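
On the plugins/caption_plugins.py side, _clean_metadata consumes that remove_keys value: it pops the comma-separated keys out of the folder's metadata.json dict in place before the dict is serialized into the prompt. A sketch with hypothetical metadata contents (dict.pop with a None default makes absent keys a no-op):

import json

metadata = {"title": "sunset", "tags": ["sky", "orange"], "score": 0.93}  # hypothetical metadata.json
remove_keys = "tags,score"

for key in remove_keys.split(","):
    metadata.pop(key, None)  # absent keys are skipped silently

# only {"title": "sunset"} survives to be embedded in the prompt
print(json.dumps(metadata, indent=2))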