enhance cog caption script

Victor Hall 2024-04-02 16:35:49 -04:00
parent 65bd7b3b92
commit 7ea9676da8
2 changed files with 38 additions and 18 deletions


@@ -39,10 +39,12 @@ from utils.patch_cog import patch_cog
 from data.generators import image_path_generator, SUPPORTED_EXT

-from moai.load_moai import prepare_moai
+try:
+    from moai.load_moai import prepare_moai
+except ImportError:
+    print("moai not found, skipping")

 Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
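
This hunk turns moai into an optional dependency: if the package is missing, the import error is swallowed and the CogVLM path still works. A minimal sketch of the same guard pattern; the HAS_MOAI flag is an illustration, not part of the commit:

    try:
        from moai.load_moai import prepare_moai
        HAS_MOAI = True   # hypothetical availability flag, not in the commit
    except ImportError:
        print("moai not found, skipping")
        HAS_MOAI = False  # callers could check this before constructing MoaiManager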
@@ -89,10 +91,10 @@ def build_conversation_input_ids(
     attention_mask = [1] * len(input_ids)

     return {
-        'input_ids': torch.tensor(input_ids, dtype=torch.long),
-        'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
-        'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
-        'images': images,
+        "input_ids": torch.tensor(input_ids, dtype=torch.long),
+        "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
+        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+        "images": images,
     }

 def get_gpu_memory_map():
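
Only the quote style changes here, but this dict is the tensor bundle the CogVLM chat path consumes downstream. A hedged illustration of its structure (argument values are examples, not from the commit):

    out = build_conversation_input_ids(tokenizer, query="Describe this image.",
                                       history=[], images=[image])
    out["input_ids"]       # 1-D LongTensor of prompt token ids
    out["token_type_ids"]  # same length; distinguishes vision patch positions from language tokens
    out["attention_mask"]  # all ones, so every position is attended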
@@ -138,7 +140,7 @@ class MoaiManager:
         self.sgg_model = None
         self.ocr_model = None

-    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str='fp16'):
+    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
         moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
             = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
         self.moai_model = moai_model
@@ -162,13 +164,13 @@ class MoaiManager:
             od_processor=self.od_processor,
             sgg_model=self.sgg_model,
             ocr_model=self.ocr_model,
-            device='cuda:0')
+            device="cuda:0")
         return moai_inputs

     def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
         with torch.inference_mode():
             generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
-            answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split('[U')[0]
+            answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
         return answer

 class CogVLMManager:
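
For orientation, a hedged sketch of how the MoaiManager methods chain together; the name of the inputs-building method is truncated out of this hunk, so get_inputs, its signature, and the model path are assumptions:

    manager = MoaiManager("BK-Lee/MoAI-7B")                          # model path assumed
    manager.load_model(bits=4, dtype="fp16")
    moai_inputs = manager.get_inputs(image, "Describe this image.")  # hypothetical method name
    caption = manager(moai_inputs, temperature=0.9, top_p=0.95)      # __call__ shown above

The split("[U")[0] in __call__ appears to trim a trailing special-token marker from the decoded answer.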
@@ -178,7 +180,7 @@ class CogVLMManager:
         self.model = None

     def load_model(self):
-        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.bfloat16,
@@ -270,7 +272,7 @@ def main(args):
         image = Image.open(image_path)

         try:
-            image = image.convert('RGB')
+            image = image.convert("RGB")
             image = ImageOps.exif_transpose(image)
         except Exception as e:
             logging.warning(f"Non-fatal error processing {image_path}: {e}")
@@ -283,12 +285,12 @@ def main(args):
         inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with) # chat mode

         inputs = {
-            'input_ids': inputs['input_ids'].unsqueeze(0).to('cuda'),
-            'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to('cuda'),
-            'attention_mask': inputs['attention_mask'].unsqueeze(0).to('cuda'),
-            'images': [[inputs['images'][0].to('cuda').to(torch.bfloat16)] for _ in range(args.num_beams)],
-            'output_hidden_states': True,
-            'return_dict': True
+            "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
+            "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
+            "attention_mask": inputs['attention_mask'].unsqueeze(0).to("cuda"),
+            "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)],
+            "output_hidden_states": True,
+            "return_dict": True
         }

         # print(f"** inputs type: {type(inputs)}") # dict
@ -407,10 +409,19 @@ if __name__ == "__main__":
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
args = argparser.parse_args()
args, unknown_args = argparser.parse_known_args()
configure_logging(args)
unknown_args_dict = {}
for i in range(0, len(unknown_args), 2):
key = unknown_args[i].lstrip('-') # Remove the leading '--'
value = unknown_args[i + 1]
unknown_args_dict[key] = value
setattr(args, key, value) # Add each unknown argument to the args namespace
logging.info(f"** Unknown args have been added to args for plugins: {Fore.LIGHTGREEN_EX}{unknown_args_dict}{Style.RESET_ALL}")
print(DESCRIPTION)
print(EXAMPLES)
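
The switch to parse_known_args is what lets plugins receive flags the script never declared: unknown args are consumed in strict key/value pairs, so every plugin flag must be followed by a value (a bare boolean flag would misalign the pairing and raise an IndexError). A hedged usage illustration; the script name is assumed, but remove_keys is the plugin flag used by the second file below:

    python caption_cog.py --model THUDM/cogvlm-chat-hf --remove_keys tags,source
    # After parsing, args.remove_keys == "tags,source" even though argparse never
    # registered --remove_keys; FromFolderMetadataJson reads it from the namespace.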


@@ -155,11 +155,20 @@ class FromFolderMetadataJson(PromptIdentityBase):
                          args=args)
         self.metadata_provider = MetadataProvider()

+    def _clean_metadata(self, metadata: dict, args) -> None:
+        if "remove_keys" in args:
+            keys = args.remove_keys.split(",")
+            logging.debug(f"Removing keys: {keys}")
+            for key in keys:
+                metadata.pop(key, None)
+                logging.debug(f"Removed key: {key}")
+
     def _from_metadata_json(self, args:Namespace) -> dict:
         image_path = args.image_path
         image_dir = os.path.dirname(image_path)
         metadata_json_path = os.path.join(image_dir, "metadata.json")
         metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
+        self._clean_metadata(metadata, args)
         metadata = json.dumps(metadata, indent=2)
         prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt)
         return prompt
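
To make the new hook concrete, a small self-contained sketch of its effect; the metadata values and flag value are hypothetical:

    from argparse import Namespace

    metadata = {"artist": "jane", "tags": "portrait,indoor", "source": "scan"}  # hypothetical values
    args = Namespace(remove_keys="tags,source")  # as delivered by the plugin-args mechanism above

    if "remove_keys" in args:            # argparse.Namespace supports membership tests
        for key in args.remove_keys.split(","):
            metadata.pop(key, None)      # pop with a default: missing keys are ignored

    print(metadata)  # {'artist': 'jane'} -- only surviving keys reach the prompt

_clean_metadata mutates the dict in place rather than returning it, which is why _from_metadata_json can keep using its local metadata variable after the call.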