adding filters for llava

This commit is contained in:
Victor Hall 2024-05-06 01:07:19 -04:00
parent d29368eb85
commit 1fe3d9f4e5
1 changed files with 30 additions and 8 deletions

View File

@ -163,16 +163,24 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
def _get_full_sentences(self, caption, args):
logging.debug(f"**DEBUG: XtunerLlava presplit caption: {caption}")
if args.max_length is not None and len(caption) > args.max_length:
caption = caption[:args.max_length]
def _clean_caption(self, caption, args):
"""
Clean up the caption by removing any newlines and excess whitespace, and removes some nonsense Llava adds.
"""
logging.debug(f"**Llava pre-cleaning caption: {caption}")
caption = caption.split(".")
#sentence_count = min(4, len(caption))
caption = ". ".join(caption[0:-1]) + "."
caption = caption.replace("\n","")
caption = caption.replace(" "," ")
caption = re.sub(r"The image does not contain .*?\.", "", caption)
caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
caption = re.sub(r", adding to .*? overall appearance", "", caption)
caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
caption = re.sub(r"hinting at .*?\.", "", caption)
caption = caption.replace(", who is the main subject of the image,", "")
logging.debug(f"**DEBUG: caption: {caption}")
logging.debug(f"**Llava post-cleaning caption: {caption}")
return caption
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
@ -182,6 +190,10 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
# inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)
# inputs['input_ids'].shape: torch.Size([1, 34])
# inputs['attention_mask'].shape: torch.Size([1, 34])
# inputs['pixel_values'].shape: torch.Size([1, 3, 336, 336])
inputs = {
"input_ids": inputs["input_ids"],
"attention_mask": inputs['attention_mask'],
@ -196,7 +208,7 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
caption = self.processor.decode(outputs[0][len_inputs:], skip_special_tokens=True)
caption = self._get_full_sentences(caption, args)
caption = self._clean_caption(caption, args)
return caption
class MoaiManager:
@ -312,6 +324,10 @@ class CogVLMManager(BaseModelWrapper):
inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)
# inputs['input_ids'].shape: torch.Size([1259])
# inputs['attention_mask'].shape: torch.Size([1259])
# inputs['images'][0].shape: torch.Size([3, 490, 490])
inputs = {
"input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
"token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
@ -320,6 +336,12 @@ class CogVLMManager(BaseModelWrapper):
"output_hidden_states": True,
"return_dict": True
}
# inputs['input_ids'].shape: torch.Size([1, 1259])
# inputs['attention_mask'].shape: torch.Size([1, 1259])
# inputs['images'][0][0].shape: torch.Size([3, 490, 490])
# len(inputs['images'][0]): 1
# len(inputs['images'][0][0]): 3
outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
#print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
#print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
@ -374,7 +396,6 @@ def main(args):
bad_words = args.bad_words.split(",") if args.bad_words is not None else []
logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}")
bad_words_ids = model_wrapper.tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else []
#print(bad_words_ids)
logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}")
@ -475,6 +496,7 @@ DESCRIPTION = f"** {Fore.LIGHTBLUE_EX}CogVLM captioning script{Style.RESET_ALL}
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")