enhance cog caption script
parent 65bd7b3b92
commit 7ea9676da8

@@ -39,10 +39,12 @@ from utils.patch_cog import patch_cog
 from data.generators import image_path_generator, SUPPORTED_EXT
 
-from moai.load_moai import prepare_moai
+try:
+    from moai.load_moai import prepare_moai
+except ImportError:
+    print("moai not found, skipping")
 
 Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
 
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
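
Aside (illustrative, not part of the diff): the try/except above turns moai into an optional dependency. A minimal sketch of the same guard pattern with an availability flag so later code can branch instead of failing at import time; the MOAI_AVAILABLE name is an assumption, not from the commit:

    try:
        from moai.load_moai import prepare_moai  # optional extra
        MOAI_AVAILABLE = True
    except ImportError:
        print("moai not found, skipping")
        MOAI_AVAILABLE = False

    if MOAI_AVAILABLE:
        ...  # safe to call prepare_moai(...) here
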
@@ -89,10 +91,10 @@ def build_conversation_input_ids(
     attention_mask = [1] * len(input_ids)
 
     return {
-        'input_ids': torch.tensor(input_ids, dtype=torch.long),
-        'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
-        'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
-        'images': images,
+        "input_ids": torch.tensor(input_ids, dtype=torch.long),
+        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+        "images": images,
     }
 
 def get_gpu_memory_map():
@@ -138,7 +140,7 @@ class MoaiManager:
         self.sgg_model = None
         self.ocr_model = None
 
-    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str='fp16'):
+    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
         moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
             = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
         self.moai_model = moai_model
@@ -162,13 +164,13 @@ class MoaiManager:
                                       od_processor=self.od_processor,
                                       sgg_model=self.sgg_model,
                                       ocr_model=self.ocr_model,
-                                      device='cuda:0')
+                                      device="cuda:0")
         return moai_inputs
 
     def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
         with torch.inference_mode():
             generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
-            answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split('[U')[0]
+            answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
         return answer
 
 class CogVLMManager:
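
Aside (illustrative, not part of the diff): __call__ trims the decoded text at the first "[U" marker so only the caption survives. A self-contained sketch of just that string step; the sample text and the exact marker token are assumptions:

    raw = "a photo of a cat sitting on a windowsill [UNUSED_TOKEN_145]"
    answer = raw.split("[U")[0]
    print(answer)  # -> "a photo of a cat sitting on a windowsill " (trailing space kept)
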
@@ -178,7 +180,7 @@ class CogVLMManager:
         self.model = None
 
     def load_model(self):
-        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.bfloat16,
@@ -270,7 +272,7 @@ def main(args):
         image = Image.open(image_path)
 
         try:
-            image = image.convert('RGB')
+            image = image.convert("RGB")
             image = ImageOps.exif_transpose(image)
         except Exception as e:
             logging.warning(f"Non-fatal error processing {image_path}: {e}")
@@ -283,12 +285,12 @@ def main(args):
         inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with) # chat mode
 
         inputs = {
-            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
-            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
-            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
-            'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
-            'output_hidden_states': True,
-            'return_dict': True
+            "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+            "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
+            "attention_mask": inputs['attention_mask'].unsqueeze(0).to("cuda"),
+            "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)],
+            "output_hidden_states": True,
+            "return_dict": True
         }
 
         # print(f"** inputs type: {type(inputs)}") # dict
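
Aside (illustrative, not part of the diff): the "images" entry is a list of per-beam image lists, replicated num_beams times because beam search expands the batch inside generate(). A shape-only sketch, with the tensor size taken from IMAGE_SIZE above but otherwise assumed:

    import torch

    num_beams = 3
    image = torch.zeros(3, 490, 490, dtype=torch.bfloat16)  # C x IMAGE_SIZE x IMAGE_SIZE
    images = [[image] for _ in range(num_beams)]  # one single-image list per beam
    print(len(images), len(images[0]))  # -> 3 1
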
@@ -407,10 +409,19 @@ if __name__ == "__main__":
     argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
     argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
     argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
-    args = argparser.parse_args()
+
+    args, unknown_args = argparser.parse_known_args()
 
     configure_logging(args)
 
+    unknown_args_dict = {}
+    for i in range(0, len(unknown_args), 2):
+        key = unknown_args[i].lstrip('-') # Remove the leading '--'
+        value = unknown_args[i + 1]
+        unknown_args_dict[key] = value
+        setattr(args, key, value) # Add each unknown argument to the args namespace
+
+    logging.info(f"** Unknown args have been added to args for plugins: {Fore.LIGHTGREEN_EX}{unknown_args_dict}{Style.RESET_ALL}")
+
     print(DESCRIPTION)
     print(EXAMPLES)
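
Aside (illustrative, not part of the diff): the switch to parse_known_args is what lets plugin-only flags such as --remove_keys (read by FromFolderMetadataJson below) reach args without being declared on the parser. A self-contained demo of the pairing logic with a simulated command line:

    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf")
    args, unknown_args = argparser.parse_known_args(["--remove_keys", "tags,source"])
    for i in range(0, len(unknown_args), 2):
        key = unknown_args[i].lstrip('-')  # strip the leading '--'
        setattr(args, key, unknown_args[i + 1])
    print(args.remove_keys)  # -> tags,source

Note that the pairing assumes every unknown flag carries a value; a bare boolean flag would shift the key/value pairs.
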
@@ -155,11 +155,20 @@ class FromFolderMetadataJson(PromptIdentityBase):
                          args=args)
         self.metadata_provider = MetadataProvider()
 
+    def _clean_metadata(self, metadata: dict, args) -> dict:
+        if "remove_keys" in args:
+            keys = args.remove_keys.split(",")
+            logging.debug(f"Removing keys: {keys}")
+            for key in keys:
+                metadata.pop(key, None)
+                logging.debug(f"Removed key: {key}")
+
     def _from_metadata_json(self, args:Namespace) -> dict:
         image_path = args.image_path
         image_dir = os.path.dirname(image_path)
         metadata_json_path = os.path.join(image_dir, "metadata.json")
         metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
+        self._clean_metadata(metadata, args)
         metadata = json.dumps(metadata, indent=2)
         prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt)
         return prompt
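
Aside (illustrative, not part of the diff): _clean_metadata mutates the dict in place and returns None despite its -> dict annotation, so the call site correctly ignores its return value. The net effect of remove_keys on a metadata.json payload, with file contents assumed for the example:

    metadata = {"artist": "jane", "tags": "a,b", "source": "web"}  # assumed metadata.json contents
    for key in "tags,source".split(","):  # the value --remove_keys would carry
        metadata.pop(key, None)
    print(metadata)  # -> {'artist': 'jane'}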