adding filters for llava

parent d29368eb85
commit 1fe3d9f4e5
@@ -163,16 +163,24 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
        return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")

    def _get_full_sentences(self, caption, args):
        logging.debug(f"**DEBUG: XtunerLlava presplit caption: {caption}")
        if args.max_length is not None and len(caption) > args.max_length:
            caption = caption[:args.max_length]
        return caption  # without an explicit return, callers assigning the result would get None

    def _clean_caption(self, caption, args):
        """
        Clean up the caption by removing newlines and excess whitespace, and by removing some boilerplate phrases Llava likes to add.
        """
        logging.debug(f"**Llava pre-cleaning caption: {caption}")
        caption = caption.split(".")
        #sentence_count = min(4, len(caption))
        caption = ". ".join(caption[0:-1]) + "."  # drop the trailing partial sentence
        caption = caption.replace("\n", "")
        caption = caption.replace("  ", " ")
        caption = re.sub(r"The image does not contain .*?\.", "", caption)
        caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
        caption = re.sub(r", adding to .*? overall appearance", "", caption)
        caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
        caption = re.sub(r"hinting at .*?\.", "", caption)
        caption = caption.replace(", who is the main subject of the image,", "")

        logging.debug(f"**Llava post-cleaning caption: {caption}")
        return caption

    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
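
A minimal standalone sketch of what the new filters do, using a fabricated caption in the style Llava tends to produce (the substitutions are copied verbatim from _clean_caption above):

    import re

    caption = ("A man in a red jacket stands on a pier, hinting at an early morning walk. "
               "The image does not contain any text. The rest of the pier is not visible "
               "in the image, focusing the viewer's attention on the man.")
    caption = re.sub(r"The image does not contain .*?\.", "", caption)
    caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
    caption = re.sub(r"hinting at .*?\.", "", caption)
    print(caption)  # "A man in a red jacket stands on a pier," plus leftover spaces;
                    # note the replace("  ", " ") pass runs before these substitutions,
                    # so the gaps they leave behind are not collapsed
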
@@ -182,6 +190,10 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
        # inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

        # inputs['input_ids'].shape: torch.Size([1, 34])
        # inputs['attention_mask'].shape: torch.Size([1, 34])
        # inputs['pixel_values'].shape: torch.Size([1, 3, 336, 336])

        inputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs['attention_mask'],
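
For context, the Llama-3 style chat template built in the first hunk renders to a flat string like the following; the prompt and starts_with values here are made up:

    prompt = "Describe the image."
    starts_with = "A photo of"
    rendered = (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
    # <|start_header_id|>user<|end_header_id|>
    #
    # <image>
    # Describe the image.<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    #
    # A photo of
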
@@ -196,7 +208,7 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner

        caption = self.processor.decode(outputs[0][len_inputs:], skip_special_tokens=True)

        caption = self._get_full_sentences(caption, args)
        caption = self._clean_caption(caption, args)
        return caption


class MoaiManager:
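
One detail worth noting in the hunk above: for a decoder-only model, generate() returns the prompt tokens followed by the newly generated ones, so the outputs[0][len_inputs:] slice decodes only the generated continuation. A runnable sketch with stand-in tensors (shapes follow the comments earlier in the diff):

    import torch

    input_ids = torch.zeros(1, 34, dtype=torch.long)   # tokenized prompt, 34 tokens
    outputs = torch.zeros(1, 50, dtype=torch.long)     # generate() output: prompt + 16 new tokens
    len_inputs = input_ids.shape[1]
    new_tokens = outputs[0][len_inputs:]               # outputs[0] = first sequence in the batch
    assert new_tokens.shape[0] == 50 - 34              # only the continuation remains
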
@@ -312,6 +324,10 @@ class CogVLMManager(BaseModelWrapper):

        inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)

        # inputs['input_ids'].shape: torch.Size([1259])
        # inputs['attention_mask'].shape: torch.Size([1259])
        # inputs['images'][0].shape: torch.Size([3, 490, 490])

        inputs = {
            "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
            "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
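
The shape comments in this hunk and the next line up: _build_conversation_input_ids returns unbatched tensors, and unsqueeze(0) adds the batch dimension before generate is called. A tiny runnable check:

    import torch

    input_ids = torch.zeros(1259, dtype=torch.long)            # unbatched, as returned above
    assert input_ids.shape == torch.Size([1259])
    assert input_ids.unsqueeze(0).shape == torch.Size([1, 1259])
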
@@ -320,6 +336,12 @@ class CogVLMManager(BaseModelWrapper):
            "output_hidden_states": True,
            "return_dict": True
        }
        # inputs['input_ids'].shape: torch.Size([1, 1259])
        # inputs['attention_mask'].shape: torch.Size([1, 1259])
        # inputs['images'][0][0].shape: torch.Size([3, 490, 490])
        # len(inputs['images'][0]): 1
        # len(inputs['images'][0][0]): 3

        outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
        #print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
        #print(f"type of hidden_states: {type(hidden_states)}, hidden_states shape: {hidden_states.shape}")
@@ -374,7 +396,6 @@ def main(args):
    bad_words = args.bad_words.split(",") if args.bad_words is not None else []
    logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
    bad_words_ids = model_wrapper.tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else []
    #print(bad_words_ids)

    logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")
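
For reference, the bad_words_ids built above is a list of token-id lists, one per comma-separated word, which is the structure Hugging Face generate() expects for its bad_words_ids argument. A shape-only sketch (the token ids are fabricated):

    bad_words = "blurry,watermark".split(",")   # as parsed from --bad_words
    bad_words_ids = [[34093], [92360, 3891]]    # shape of tokenizer(bad_words, add_special_tokens=False)["input_ids"]
    # model.generate(..., bad_words_ids=bad_words_ids) then refuses to emit these id sequences.
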
@@ -475,6 +496,7 @@ DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL}

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with CogVLM! (def: 1)")
    argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
    argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM; fallback is bf16.")
    argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")