add unicode cleanup to caption_cog and a few more llava cleanups
This commit is contained in:
parent
1fe3d9f4e5
commit
e7201d87db
|
@ -31,9 +31,9 @@ import PIL.ImageOps as ImageOps
|
|||
from pynvml import *
|
||||
|
||||
from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer
|
||||
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from colorama import Fore, Style
|
||||
from unidecode import unidecode
|
||||
|
||||
from plugins.caption_plugins import load_prompt_alteration_plugin
|
||||
from utils.patch_cog import patch_cog
|
||||
|
@ -70,10 +70,10 @@ def save_params(args, gen_kwargs):
|
|||
with open(save_path, "w") as f:
|
||||
f.write(pretty_print)
|
||||
|
||||
def create_bnb_config():
|
||||
def create_bnb_config(bnb_4bit_compute_dtype="bfloat16",bnb_4bit_quant_type= "fp4"):
|
||||
return BitsAndBytesConfig(
|
||||
bnb_4bit_compute_dtype="bfloat16",
|
||||
bnb_4bit_quant_type= "fp4",
|
||||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
||||
bnb_4bit_quant_type=bnb_4bit_quant_type,
|
||||
bnb_4bit_use_double_quant=False,
|
||||
llm_int8_enable_fp32_cpu_offload=False,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
|
@ -115,30 +115,29 @@ class BaseModelWrapper:
|
|||
"length_penalty": args.length_penalty,
|
||||
}
|
||||
|
||||
logging.info(gen_kwargs)
|
||||
#logging.debug(gen_kwargs)
|
||||
|
||||
if args.max_new_tokens is not None:
|
||||
logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
|
||||
logging.debug(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
|
||||
del gen_kwargs["max_length"]
|
||||
|
||||
if not gen_kwargs["do_sample"]:
|
||||
logging.info(f"** Using greedy sampling")
|
||||
logging.debug(f"** Using greedy sampling")
|
||||
del gen_kwargs["top_k"]
|
||||
del gen_kwargs["top_p"]
|
||||
del gen_kwargs["temperature"]
|
||||
else:
|
||||
logging.info(f"** Sampling enabled")
|
||||
logging.debug(f"** Sampling enabled")
|
||||
return gen_kwargs
|
||||
|
||||
def caption(prompt, args):
|
||||
return ""
|
||||
|
||||
class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers
|
||||
class XtunerLlavaModelManager(BaseModelWrapper):
|
||||
def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
|
||||
self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
|
||||
super().__init__(model_name)
|
||||
|
||||
|
||||
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
|
||||
self.model = LlavaForConditionalGeneration.from_pretrained(
|
||||
#self.model = AutoModelForCausalLM.from_pretrained(
|
||||
|
@ -163,24 +162,41 @@ class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner
|
|||
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
|
||||
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
|
||||
|
||||
def _clean_caption(self, caption, args):
|
||||
"""
|
||||
Clean up the caption by removing any newlines and excess whitespace, and removes some nonsense Llava adds.
|
||||
"""
|
||||
logging.debug(f"**Llava pre-cleaning caption: {caption}")
|
||||
def _truncate_to_whole_sentences(self, caption):
|
||||
# model does not stop generating cleanly and cuts off mid sentence
|
||||
caption = caption.split(".")
|
||||
#sentence_count = min(4, len(caption))
|
||||
caption = ". ".join(caption[0:-1]) + "."
|
||||
caption = caption.replace("\n","")
|
||||
caption = caption.replace(" "," ")
|
||||
caption = re.sub(r"The image does not contain .*?\.", "", caption)
|
||||
caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
|
||||
caption = re.sub(r", adding to .*? overall appearance", "", caption)
|
||||
caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
|
||||
caption = re.sub(r"hinting at .*?\.", "", caption)
|
||||
caption = caption.replace(", who is the main subject of the image,", "")
|
||||
return caption
|
||||
|
||||
logging.debug(f"**Llava post-cleaning caption: {caption}")
|
||||
def _clean_caption(self, caption, args):
|
||||
"""
|
||||
Removes some nonsense Llava adds.
|
||||
"""
|
||||
if not args.no_clean:
|
||||
logging.debug(f"**Llava pre-cleaning caption: {caption}")
|
||||
caption = caption.replace("**", "")
|
||||
caption = re.sub(r"The image does not contain .*?\.", "", caption)
|
||||
caption = re.sub(r"Please note that this description is based on .*?\.", "", caption)
|
||||
caption = re.sub(r", adding to .*? overall appearance", "", caption)
|
||||
caption = re.sub(r"The rest of .*? is not visible in the image, focusing .*?\.", "", caption)
|
||||
caption = re.sub(r", adding to the .*? of the image", "", caption)
|
||||
caption = re.sub(r", making .*? the focal point of the image", "", caption)
|
||||
caption = re.sub(r", adding .*? to the scene", "", caption)
|
||||
caption = re.sub(r", adding an element of .*? to .*?\.",".", caption) # [intrigue, color, etc] .. [the image, the scene, etc]
|
||||
caption = re.sub(r", hinting at .*?\.", ".", caption)
|
||||
caption = re.sub(r"hinting at .*?\.", ".", caption)
|
||||
caption = re.sub(r", .*? is the main subject of the .*?\.",".", caption) # [who, which, etc] .. [image, photo, etc]
|
||||
caption = re.sub(r", .*? is the main subject of the .*?,",".", caption)
|
||||
caption = caption.replace(", who is the main subject of the image,", "")
|
||||
caption = caption.replace(", which is the main subject of the image,", "")
|
||||
caption = caption.replace(", who is the main subject of the photo.", ".")
|
||||
caption = caption.replace(", who is the main subject.", ".")
|
||||
caption = caption.replace("who is the main subject.", ".")
|
||||
caption = self._truncate_to_whole_sentences(caption)
|
||||
|
||||
logging.debug(f"**Llava post-cleaning caption: {caption}")
|
||||
return caption
|
||||
|
||||
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
|
||||
|
@ -370,6 +386,9 @@ def get_inputs_dict(inputs):
|
|||
"return_dict": True
|
||||
}
|
||||
|
||||
def replace_non_utf8_chars(text) -> str:
|
||||
|
||||
|
||||
def main(args):
|
||||
prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
|
||||
model_wrapper = get_model_wrapper(args.model)
|
||||
|
@ -445,7 +464,10 @@ def main(args):
|
|||
|
||||
caption += args.append
|
||||
|
||||
with open(candidate_caption_path, "w") as f:
|
||||
if not args.no_clean:
|
||||
caption = unidecode(caption)
|
||||
|
||||
with open(candidate_caption_path, "w", encoding="utf-8") as f:
|
||||
f.write(caption)
|
||||
vram_gb = get_gpu_memory_map()
|
||||
elapsed_time = time.time() - cap_start_time
|
||||
|
@ -512,6 +534,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--prompt", type=str, default="Write a description.", help="Prompt that will guide captioning")
|
||||
argparser.add_argument("--image_dir", type=str, default=None, help="Path to folder of images to caption")
|
||||
argparser.add_argument("--no_overwrite", action="store_true", help="Skips captioning images that already have a caption file.")
|
||||
argparser.add_argument("--no_clean", action="store_true", help="Skips cleaning of \"junk\" phrases")
|
||||
argparser.add_argument("--force_words", type=str, default=None, help="Forces the model to include these words in the caption, use CSV format.")
|
||||
argparser.add_argument("--bad_words", type=str, default=None, help="Words that will not be allowed, use CSV format.")
|
||||
argparser.add_argument("--append", type=str, default=None, help="Extra string to append to all captions. ex. 'painted by John Doe'")
|
||||
|
@ -527,9 +550,15 @@ if __name__ == "__main__":
|
|||
configure_logging(args, "caption_cog.log")
|
||||
|
||||
unknown_args_dict = {}
|
||||
for i in range(0, len(unknown_args), 2):
|
||||
key = unknown_args[i].lstrip('-') # Remove the leading '--'
|
||||
value = unknown_args[i + 1]
|
||||
print(unknown_args)
|
||||
print(len(unknown_args))
|
||||
for i in range(0, len(unknown_args)-1, 1):
|
||||
key = unknown_args[i].lstrip('-')
|
||||
if unknown_args[i+1].startswith("-"): # "store_true" instead of a kvp
|
||||
value = True
|
||||
else:
|
||||
value = unknown_args[i+1] # value is next item for all kvp in unknown args
|
||||
i += 1 # skip over the value of the kvp for next iteration to get next key
|
||||
unknown_args_dict[key] = value
|
||||
setattr(args, key, value) # Add each unknown argument to the args namespace
|
||||
|
||||
|
|
|
@ -19,3 +19,4 @@ safetensors
|
|||
prodigyopt
|
||||
torchsde
|
||||
peft==0.9.0
|
||||
unidecode
|
|
@ -25,6 +25,7 @@ pip install safetensors
|
|||
pip install prodigyopt
|
||||
pip install torchsde
|
||||
pip install peft>=0.9.0
|
||||
pip install unidecode
|
||||
python utils/get_yamls.py
|
||||
GOTO :eof
|
||||
|
||||
|
|
Loading…
Reference in New Issue