add unicode cleanup to caption_cog and a few more llava cleanups

This commit is contained in:
Victor Hall 2024-05-15 14:23:01 -04:00
parent 1fe3d9f4e5
commit e7201d87db
3 changed files with 58 additions and 27 deletions

View File

@ -31,9 +31,9 @@ import PIL.ImageOps as ImageOps
from pynvml import *
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
from transformers.modeling_outputs import BaseModelOutputWithPast
from colorama import Fore, Style
from unidecode import unidecode
from plugins.caption_plugins import load_prompt_alteration_plugin
from utils.patch_cog import patch_cog
@ -70,10 +70,10 @@ def save_params(args, gen_kwargs):
with open(save_path, "w") as f:
f.write(pretty_print)
def create_bnb_config():
def create_bnb_config(bnb_4bit_compute_dtype="bfloat16",bnb_4bit_quant_type= "fp4"):
return BitsAndBytesConfig(
bnb_4bit_compute_dtype="bfloat16",
bnb_4bit_quant_type= "fp4",
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_use_double_quant=False,
llm_int8_enable_fp32_cpu_offload=False,
llm_int8_has_fp16_weight=False,
@ -115,30 +115,29 @@ class BaseModelWrapper:
"length_penalty": args.length_penalty,
}
logging.info(gen_kwargs)
#logging.debug(gen_kwargs)
if args.max_new_tokens is not None:
logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
logging.debug(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
del gen_kwargs["max_length"]
if not gen_kwargs["do_sample"]:
logging.info(f"** Using greedy sampling")
logging.debug(f"** Using greedy sampling")
del gen_kwargs["top_k"]
del gen_kwargs["top_p"]
del gen_kwargs["temperature"]
else:
logging.info(f"** Sampling enabled")
logging.debug(f"** Sampling enabled")
return gen_kwargs
def caption(prompt, args):
return ""
class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers
class XtunerLlavaModelManager(BaseModelWrapper):
def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
super().__init__(model_name)
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
self.model = LlavaForConditionalGeneration.from_pretrained(
#self.model = AutoModelForCausalLM.from_pretrained(
@ -163,24 +162,41 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
def _clean_caption(self, caption, args):
"""
Clean up the caption by removing any newlines and excess whitespace, and removes some nonsense Llava adds.
"""
logging.debug(f"**Llava pre-cleaning caption: {caption}")
def _truncate_to_whole_sentences(self, caption):
# model does not stop generating cleanly and cuts off mid sentence
caption = caption.split(".")
#sentence_count = min(4, len(caption))
caption = ". ".join(caption[0:-1]) + "."
caption = caption.replace("\n","")
caption = caption.replace(" "," ")
caption = re.sub(r"The image does not contain .*?\.", "", caption)
caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
caption = re.sub(r", adding to .*? overall appearance", "", caption)
caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
caption = re.sub(r"hinting at .*?\.", "", caption)
caption = caption.replace(", who is the main subject of the image,", "")
return caption
logging.debug(f"**Llava post-cleaning caption: {caption}")
def _clean_caption(self, caption, args):
"""
Removes some nonsense Llava adds.
"""
if not args.no_clean:
logging.debug(f"**Llava pre-cleaning caption: {caption}")
caption = caption.replace("**", "")
caption = re.sub(r"The image does not contain .*?\.", "", caption)
caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
caption = re.sub(r", adding to .*? overall appearance", "", caption)
caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
caption = re.sub(r", adding to the .*? of the image", "", caption)
caption = re.sub(r", making .*? the focal point of the image", "", caption)
caption = re.sub(r", adding .*? to the scene", "", caption)
caption = re.sub(r", adding an element of .*? to .*?\.",".", caption) # [intrigue, color, etc] .. [the image, the scene, etc]
caption = re.sub(r", hinting at .*?\.", ".", caption)
caption = re.sub(r"hinting at .*?\.", ".", caption)
caption = re.sub(r", .*? is the main subject of the .*?\.",".", caption) # [who, which, etc] .. [image, photo, etc]
caption = re.sub(r", .*? is the main subject of the .*?,",".", caption)
caption = caption.replace(", who is the main subject of the image,", "")
caption = caption.replace(", which is the main subject of the image,", "")
caption = caption.replace(", who is the main subject of the photo.", ".")
caption = caption.replace(", who is the main subject.", ".")
caption = caption.replace("who is the main subject.", ".")
caption = self._truncate_to_whole_sentences(caption)
logging.debug(f"**Llava post-cleaning caption: {caption}")
return caption
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
@ -370,6 +386,9 @@ def get_inputs_dict(inputs):
"return_dict": True
}
def replace_non_utf8_chars(text) -> str:
def main(args):
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
model_wrapper = get_model_wrapper(args.model)
@ -445,7 +464,10 @@ def main(args):
caption += args.append
with open(candidate_caption_path, "w") as f:
if not args.no_clean:
caption = unidecode(caption)
with open(candidate_caption_path, "w", encoding="utf-8") as f:
f.write(caption)
vram_gb = get_gpu_memory_map()
elapsed_time = time.time() - cap_start_time
@ -512,6 +534,7 @@ if __name__ == "__main__":
argparser.add_argument("--prompt", type=str, default="Write a description.", help="Prompt that will guide captioning")
argparser.add_argument("--image_dir", type=str, default=None, help="Path to folder of images to caption")
argparser.add_argument("--no_overwrite", action="store_true", help="Skips captioning images that already have a caption file.")
argparser.add_argument("--no_clean", action="store_true", help="Skips cleaning of \"junk\" phrases")
argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.")
argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.")
argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'")
@ -527,9 +550,15 @@ if __name__ == "__main__":
configure_logging(args, "caption_cog.log")
unknown_args_dict = {}
for i in range(0, len(unknown_args), 2):
key = unknown_args[i].lstrip('-') # Remove the leading '--'
value = unknown_args[i + 1]
print(unknown_args)
print(len(unknown_args))
for i in range(0, len(unknown_args)-1, 1):
key = unknown_args[i].lstrip('-')
if unknown_args[i+1].startswith("-"): # "store_true" instead of a kvp
value = True
else:
value = unknown_args[i+1] # value is next item for all kvp in unknown args
i += 1 # skip over the value of the kvp for next iteration to get next key
unknown_args_dict[key] = value
setattr(args, key, value) # Add each unknown argument to the args namespace

View File

@ -19,3 +19,4 @@ safetensors
prodigyopt
torchsde
peft==0.9.0
unidecode

View File

@ -25,6 +25,7 @@ pip install safetensors
pip install prodigyopt
pip install torchsde
pip install peft>=0.9.0
pip install unidecode
python utils/get_yamls.py
GOTO :eof