some cleanup/updates to caption stuff

Victor Hall 2024-03-22 13:27:01 -04:00
parent 273ce23c7c
commit 056de840d0
5 changed files with 228 additions and 81 deletions

@@ -36,7 +36,12 @@ from colorama import Fore, Style
 from plugins.caption_plugins import load_prompt_alteration_plugin
 from utils.patch_cog import patch_cog
-from data.gen_utils import image_generator, SUPPORTED_EXT
+from data.generators import image_path_generator, SUPPORTED_EXT
+try:
+    from moai.load_moai import prepare_moai
+except ImportError:
+    print("moai not found, skipping")
 
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
@@ -107,7 +112,7 @@ def save_params(args, gen_kwargs):
     with open(save_path, "w") as f:
         f.write(pretty_print)
 
-def create_bnb_config(args):
+def create_bnb_config():
     return BitsAndBytesConfig(
         bnb_4bit_compute_dtype="float32",
         bnb_4bit_quant_type="fp4",
@@ -121,58 +126,117 @@ def create_bnb_config(args):
         quant_method="bitsandbytes"
     )
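
For orientation, a `BitsAndBytesConfig` like this is simply handed to `from_pretrained`; a minimal sketch with a placeholder model id (the real call site is `CogVLMManager.load_model` below):

```python
import torch
from transformers import AutoModelForCausalLM

# placeholder model id; weights are quantized to 4-bit fp4, compute runs in float32
model = AutoModelForCausalLM.from_pretrained(
    "some/causal-lm",
    torch_dtype=torch.bfloat16,
    quantization_config=create_bnb_config(),
)
```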
+class MoaiManager:
+    """Wraps the MoAI model plus the auxiliary perception models it relies on (segmentation, object detection, scene graph, OCR)."""
+    def __init__(self, model_name: str):
+        self.model_name = model_name
+        self.moai_model = None
+        self.moai_processor = None
+        self.seg_model = None
+        self.seg_processor = None
+        self.od_model = None
+        self.od_processor = None
+        self.sgg_model = None
+        self.ocr_model = None
+
+    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str='fp16'):
+        moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
+            = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
+        self.moai_model = moai_model
+        self.moai_processor = moai_processor
+        self.seg_model = seg_model
+        self.seg_processor = seg_processor
+        self.od_model = od_model
+        self.od_processor = od_processor
+        self.sgg_model = sgg_model
+        self.ocr_model = ocr_model
+        return moai_model, moai_processor
+
+    def get_inputs(self, image: Image.Image, prompt: str):
+        moai_inputs = self.moai_model.demo_process(image=image,
+                                                   prompt=prompt,
+                                                   processor=self.moai_processor,
+                                                   seg_model=self.seg_model,
+                                                   seg_processor=self.seg_processor,
+                                                   od_model=self.od_model,
+                                                   od_processor=self.od_processor,
+                                                   sgg_model=self.sgg_model,
+                                                   ocr_model=self.ocr_model,
+                                                   device='cuda:0')
+        return moai_inputs
+
+    def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
+        with torch.inference_mode():
+            generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
+            answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split('[U')[0]
+        return answer
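
A hedged sketch of how this manager is meant to be driven end to end, assuming the optional moai package imported at the top of the file is installed; the checkpoint path and image file are placeholders:

```python
from PIL import Image

manager = MoaiManager("path/to/moai-checkpoint")  # placeholder path
manager.load_model(bits=4)  # prepare_moai also loads the seg/od/sgg/ocr helpers
image = Image.open("example.jpg").convert("RGB")  # placeholder image
moai_inputs = manager.get_inputs(image=image, prompt="Describe this image.")
print(manager(moai_inputs, max_new_tokens=256))
```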
+class CogVLMManager:
+    def __init__(self, model_name: str):
+        self.model_name = model_name
+        self.tokenizer = None
+        self.model = None
+
+    def load_model(self):
+        self.tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            quantization_config=create_bnb_config()
+        )
+        return self.model, self.tokenizer
+
+    def get_inputs(self, prompt: str, history: List[Tuple[str, str]], images: List[Image.Image], starts_with: str):
+        return build_conversation_input_ids(self.tokenizer, query=prompt, history=history, images=images, starts_with=starts_with)
+
+    def get_gen_kwargs(self, args):
+        gen_kwargs = {
+            "max_length": args.max_length,
+            # sample whenever any sampling knob was set on the command line
+            "do_sample": args.top_k is not None or args.top_p is not None or args.temp is not None,
+            "length_penalty": args.length_penalty,
+            "num_beams": args.num_beams,
+            "temperature": args.temp,
+            "top_k": args.top_k,
+            "top_p": args.top_p,
+            "repetition_penalty": args.repetition_penalty,
+            "no_repeat_ngram_size": args.no_repeat_ngram_size,
+            "min_new_tokens": args.min_new_tokens,
+            "max_new_tokens": args.max_new_tokens,
+        }
+        print(gen_kwargs)
+        if args.max_new_tokens is not None:
+            logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
+            del gen_kwargs["max_length"]
+        if not gen_kwargs["do_sample"]:
+            logging.info("** Using greedy sampling")
+            del gen_kwargs["top_k"]
+            del gen_kwargs["top_p"]
+            del gen_kwargs["temperature"]
+        else:
+            logging.info("** Sampling enabled")
+        return gen_kwargs
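
Because `do_sample` flips on whenever any of `--top_k`, `--top_p`, or `--temp` is given, the pruning logic can be exercised without loading any weights. A minimal sketch with placeholder values (not the script's defaults):

```python
from argparse import Namespace

base = dict(max_length=2048, length_penalty=1.0, num_beams=1,
            temp=None, top_k=None, top_p=None, repetition_penalty=1.0,
            no_repeat_ngram_size=3, min_new_tokens=5, max_new_tokens=None)

manager = CogVLMManager("THUDM/cogvlm-chat-hf")  # constructor loads no weights
greedy = manager.get_gen_kwargs(Namespace(**base))
# greedy mode strips the sampling keys, presumably to avoid transformers
# warnings about unused sampling parameters
assert greedy["do_sample"] is False and "temperature" not in greedy

sampling = manager.get_gen_kwargs(Namespace(**{**base, "temp": 0.9}))
assert sampling["do_sample"] is True and sampling["temperature"] == 0.9
```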
+def model_manager_factory(model_name: str):
+    if "moai" in model_name:
+        return MoaiManager(model_name)
+    else:
+        return CogVLMManager(model_name)
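
Note the dispatch is a case-sensitive substring test on the model name string:

```python
manager = model_manager_factory("THUDM/cogvlm-chat-hf")  # -> CogVLMManager
manager = model_manager_factory("some/moai-checkpoint")  # -> MoaiManager (placeholder path)
# a name cased like "MoAI-7B" would NOT match the lowercase "moai" test
```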
 def main(args):
     prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args)
+    model_manager = model_manager_factory(args.model)
 
-    bnb_config = create_bnb_config(args)
-    tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
-    model = AutoModelForCausalLM.from_pretrained(
-        'THUDM/cogvlm-chat-hf',
-        torch_dtype=torch.bfloat16,
-        low_cpu_mem_usage=True,
-        trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor
-        #revision=... # no one is actually doing this
-        #load_in_4bit=not args.disable_4bit,
-        quantization_config=bnb_config,
-    )
-
-    do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None
-    if do_sample:
-        args.top_k = args.top_k or 50
-        args.top_p = args.top_p or 1.0
-        args.temp = args.temp or 1.0
+    model, tokenizer = model_manager.load_model()
 
     args.append = args.append or ""
     if len(args.append) > 0:
         args.append = " " + args.append.strip()
 
-    gen_kwargs = {
-        "max_length": args.max_length,
-        "do_sample": do_sample,
-        "length_penalty": args.length_penalty,
-        "num_beams": args.num_beams,
-        "temperature": args.temp,
-        "top_k": args.top_k,
-        "top_p": args.top_p,
-        "repetition_penalty": args.repetition_penalty,
-        "no_repeat_ngram_size": args.no_repeat_ngram_size,
-        "min_new_tokens": args.min_new_tokens,
-        "max_new_tokens": args.max_new_tokens,
-        "length_penalty": args.length_penalty,
-    }
-
-    if args.max_new_tokens is not None:
-        logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length")
-        del gen_kwargs["max_length"]
-
-    if not do_sample:
-        logging.info(f"** Using greedy sampling")
-        del gen_kwargs["top_k"]
-        del gen_kwargs["top_p"]
-        del gen_kwargs["temperature"]
-    else:
-        logging.info(f"** Sampling enabled")
+    gen_kwargs = model_manager.get_gen_kwargs(args)
 
     force_words_ids = None
     if args.force_words is not None:
@@ -195,7 +259,7 @@ def main(args):
     starts_with = args.starts_with.strip() if args.starts_with is not None else ""
 
-    for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)):
+    for i, image_path in enumerate(image_path_generator(args.image_dir, do_recurse=not args.no_recurse)):
         candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt")
         if args.no_overwrite and os.path.exists(candidate_caption_path):
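
One caveat in the sidecar-path derivation above: `str.replace` substitutes every occurrence of the extension in the path, not just the trailing one. An illustration with made-up paths:

```python
import os

image_path = "/data/photos/cat_01.jpg"
candidate = image_path.replace(os.path.splitext(image_path)[-1], ".txt")
# "/data/photos/cat_01.txt"

# edge case: a directory name containing the same extension is rewritten too
"/data/a.jpg/cat.jpg".replace(".jpg", ".txt")
# -> "/data/a.txt/cat.txt"
```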
@@ -342,6 +406,7 @@ if __name__ == "__main__":
     argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
     argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
     argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
+    argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
     args = argparser.parse_args()
     configure_logging(args)
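
With the new `--model` switch, an invocation looks something like this; the script filename is an assumption, and the plugin key comes from plugins/caption_plugins.py below:

```bash
python caption_cog.py --image_dir /path/to/images \
    --model THUDM/cogvlm-chat-hf \
    --prompt_plugin from_leaf_directory
```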

data/__init__.py Normal file (new, empty)

data/gen_utils.py Deleted file

@@ -1,15 +0,0 @@
-import os
-from typing import Generator
-
-SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
-
-def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
-    if do_recurse:
-        for root, dirs, files in os.walk(image_dir):
-            for file in files:
-                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
-                    yield os.path.join(root, file)
-    else:
-        for file in os.listdir(image_dir):
-            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
-                yield os.path.join(image_dir, file)

data/generators.py Normal file (96 added lines)

@@ -0,0 +1,96 @@
"""
Copyright [2022-2024] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from typing import Generator
from data.image_train_item import ImageTrainItem, ImageCaption
from PIL import Image, ImageOps
import tarfile
import logging
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
+class BucketBatchedGenerator:
+    """
+    Yields items with the same aspect ratio (target_wh bucket) in batches, for use with batching dataloaders.
+    A plain iterable: only __iter__ is needed (subclassing typing.Generator would require send/throw).
+    Items without target_wh are skipped, and a trailing partial batch per bucket is never flushed.
+    """
+    def __init__(self, batch_size: int=1, generator: Generator[ImageTrainItem, None, None]=None):
+        self.batch_size = batch_size
+        self.cache = {}
+        self.generator = generator
+
+    def __iter__(self):
+        for item in self.generator:
+            if item.target_wh:
+                aspect_bucket_key = item.target_wh
+                if aspect_bucket_key not in self.cache:
+                    self.cache[aspect_bucket_key] = []
+                self.cache[aspect_bucket_key].append(item)
+                if len(self.cache[aspect_bucket_key]) >= self.batch_size:
+                    for item in self.cache[aspect_bucket_key]:
+                        yield item
+                    self.cache[aspect_bucket_key] = []
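
A sketch of the intended wiring, with a placeholder directory; note that because items are buffered per aspect bucket, a final partial batch is silently dropped:

```python
items = image_train_item_generator_from_files("/path/to/images")  # defined below
batched = BucketBatchedGenerator(batch_size=4, generator=items)
for item in batched:
    ...  # items arrive in runs of 4 sharing the same target_wh bucket
```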
+# def image_train_item_generator_from_tar_pairs(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
+#     for root, dirs, files in os.walk(image_dir):
+#         for file in files:
+#             if file.endswith(".tar"):
+#                 tar_path = os.path.join(root, file)
+#                 with tarfile.open(tar_path, "r") as tar:
+#                     for tarinfo in tar:
+#                         if tarinfo.isfile() and any(tarinfo.name.endswith(ext) for ext in SUPPORTED_EXT):
+#                             try:
+#                                 img = Image.open(tar.extractfile(tarinfo))
+#                                 txt = tar.extractfile(tarinfo.name.replace(os.path.splitext(tarinfo.name)[-1], ".txt"))
+#                                 caption = txt.read().decode("utf-8")
+#                                 img_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=256, use_weights=False)
+#                                 img = ImageOps.exif_transpose(img)
+#                                 iti = ImageTrainItem(img, img_caption)
+#                             except Exception as e:
+#                                 logging.error(f"Failed to open {tarinfo.name}: {e}")
+#                                 continue
+#                             yield iti
+def image_train_item_generator_from_files(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
+    for img_path in image_path_generator(image_dir, do_recurse):
+        try:
+            img = Image.open(img_path)
+            img = ImageOps.exif_transpose(img)
+        except Exception as e:
+            print(f"Failed to open {img_path}: {e}")
+            continue
+
+        txt_cap_path = img_path.replace(os.path.splitext(img_path)[-1], ".txt")
+        caption = None
+        if os.path.exists(txt_cap_path):
+            with open(txt_cap_path, "r") as f:
+                caption = f.read()
+        if not caption or len(caption) < 1:
+            # fall back to the filename stem before the first underscore
+            caption = os.path.basename(img_path).split("_")[0]
+        image_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=128, use_weights=False)
+        iti = ImageTrainItem(img, image_caption)  # assumes ImageTrainItem(image, caption, ...) ordering
+        yield iti
+def image_path_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]:
+    if do_recurse:
+        for root, dirs, files in os.walk(image_dir):
+            for file in files:
+                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
+                    yield os.path.join(root, file)
+    else:
+        for file in os.listdir(image_dir):
+            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
+                yield os.path.join(image_dir, file)
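
Callers of the deleted `image_generator` only need the new import and name; a quick sketch with a placeholder directory:

```python
from data.generators import image_path_generator

for path in image_path_generator("/path/to/images", do_recurse=True):
    print(path)  # every file under the tree ending in one of SUPPORTED_EXT
```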

plugins/caption_plugins.py

@@ -22,7 +22,7 @@ class TestSub(TestBase):
     def __repr__(self) -> str:
         return f"TestSub: {self.a}, {self.b}"
 
-class PromptIdentityPlugin():
+class PromptIdentityBase():
     """
     Base class for prompt alteration plugins, useful for captioning, etc.
     """
@@ -68,7 +68,7 @@ class PromptIdentityPlugin():
             prompt = f"Hint: {hint}\n{prompt}"
         return prompt
 
-class HintFromFilename(PromptIdentityPlugin):
+class HintFromFilename(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="hint_from_filename",
                          description="Add a hint to the prompt using the filename of the image (without extension)",
@@ -81,7 +81,7 @@ class HintFromFilename(PromptIdentityPlugin):
         prompt = self._add_hint_to_prompt(filename, prompt)
         return prompt
 
-class RemoveUsingCSV(PromptIdentityPlugin):
+class RemoveUsingCSV(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="remove_using_csv",
                          description="Removes whole word matches of the csv passed in from the prompt",
@@ -111,7 +111,7 @@ class RemoveUsingCSV(PromptIdentityPlugin):
             prompt = self._filter_logic(prompt, [word])
         return prompt
 
-class HintFromLeafDirectory(PromptIdentityPlugin):
+class HintFromLeafDirectory(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="from_leaf_directory",
                          description="Adds a hint to the prompt using the leaf directory name (last folder in path)",
@@ -130,13 +130,13 @@ class MetadataProvider():
         self._datadict = {}
 
     def _from_metadata(self, args) -> dict:
-        image_path = args.get("image_path", "")
+        image_path = args.image_path
         prompt = args.get("prompt", "")
         metadata = self._get_metadata_dict(image_path)
         return f"metadata: {metadata}\n{prompt}"
 
     def _get_metadata_dict(self, metadata_path: str) -> dict:
-        if not self.loaded and not metadata_path in self.cache:
+        if not metadata_path in self._datadict:
             metadata_dirname = os.path.dirname(metadata_path)
             if not os.path.exists(metadata_path):
                 logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}")
@@ -145,12 +145,12 @@
                 metadata = json.load(f)
             self._datadict[metadata_path] = metadata
-        return self.dict[metadata_path]
+        return self._datadict[metadata_path]
 
-class FromFolderMetadataJson(PromptIdentityPlugin):
+class FromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="from_folder_metadata",
-                         description="Looks for metadata.json in the folder of the images",
+                         description="Looks for metadata.json in the folder of the images and prefixes it to the prompt",
                          fn=self._from_metadata_json,
                          args=args)
         self.metadata_provider = MetadataProvider()
@@ -159,11 +159,12 @@ class FromFolderMetadataJson(PromptIdentityPlugin):
         image_path = args.image_path
         image_dir = os.path.dirname(image_path)
         metadata_json_path = os.path.join(image_dir, "metadata.json")
-        self.metadata_provider._get_metadata_dict(metadata_json_path)
-        return ""
+        metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
+        metadata = json.dumps(metadata, indent=2)
+        prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt)
+        return prompt
 
-class TagsFromFolderMetadataJson(PromptIdentityPlugin):
+class TagsFromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         self.cache = {}
         super().__init__(key = "tags_from_metadata_json",
@@ -185,7 +186,7 @@ class TagsFromFolderMetadataJson(PromptIdentityPlugin):
             return self._add_hint_to_prompt(f"tags: {tags}", prompt)
         return prompt
 
-class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin):
+class TitleAndTagsFromImageJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="title_and_tags_from_image_json",
                          description="Adds title and tags hint from metadata.json (in the same folder as the image) to the prompt",
@@ -218,7 +219,7 @@ class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin):
         logging.debug(f" {self.key}: prompt after: {prompt}")
         return prompt
 
-class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
+class TitleAndTagsFromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         self.cache = {}
         super().__init__(key="title_and_tags_from_metadata_json",
@@ -254,7 +255,7 @@ class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
         logging.debug(f" {self.key}: prompt after: {prompt}")
         return prompt
 
-class TitleAndTagsFromGlobalMetadataJson(PromptIdentityPlugin):
+class TitleAndTagsFromGlobalMetadataJson(PromptIdentityBase):
     """
     Adds title and tags hint from global metadata json given by '--metadatafilename'
     Note: you could just put your metadata in the prompt instead of using this plugin, but perhaps useful?
@@ -316,8 +317,8 @@ def get_prompt_alteration_plugin_list() -> list:
             if isinstance(attribute, type) \
                 and attribute.__module__ == module.__name__ \
-                and is_subclass_of_subclass(attribute, PromptIdentityPlugin, recursion_depth=5) \
-                and attribute is not PromptIdentityPlugin:
+                and is_subclass_of_subclass(attribute, PromptIdentityBase, recursion_depth=5) \
+                and attribute is not PromptIdentityBase:
                 plugins.append(attribute)
 
     #print(f"done checking plugins_module_name: {plugins_module_name}")
@@ -337,4 +338,4 @@ def load_prompt_alteration_plugin(plugin_key: str, args) -> callable:
         raise ValueError(f"plugin_key: {plugin_key} not found in prompt_alteration_plugins")
     else:
         logging.info("No plugin specified")
-        return PromptIdentityPlugin(args=args)
+        return PromptIdentityBase(args=args)
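
Finally, plugins are selected by their `key`; a minimal sketch of loading one outside the captioning script (the Namespace fields are placeholders for the argparse args a plugin may read):

```python
from argparse import Namespace
from plugins.caption_plugins import load_prompt_alteration_plugin

args = Namespace(prompt_plugin="hint_from_filename", image_path="/path/img.jpg")
plugin_fn = load_prompt_alteration_plugin("hint_from_filename", args=args)
```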