diff --git a/caption_cog.py b/caption_cog.py
index d001584..cae4d4a 100644
--- a/caption_cog.py
+++ b/caption_cog.py
@@ -36,7 +36,12 @@ from colorama import Fore, Style
 from plugins.caption_plugins import load_prompt_alteration_plugin
 from utils.patch_cog import patch_cog
-from data.gen_utils import image_generator, SUPPORTED_EXT
+from data.generators import image_path_generator, SUPPORTED_EXT
+
+try:
+    from moai.load_moai import prepare_moai
+except ImportError:
+    print("moai not found, skipping")
 
 IMAGE_SIZE: int = 490
 PATCH_SIZE: int = 14
@@ -107,7 +112,7 @@ def save_params(args, gen_kwargs):
     with open(save_path, "w") as f:
         f.write(pretty_print)
 
-def create_bnb_config(args):
+def create_bnb_config():
     return BitsAndBytesConfig(
         bnb_4bit_compute_dtype="float32",
         bnb_4bit_quant_type= "fp4",
@@ -121,58 +126,117 @@ def create_bnb_config(args):
         quant_method="bitsandbytes"
     )
 
+class MoaiManager:
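+    """
+    Manager for the optional MoAI captioning pipeline. load_model defers to
+    prepare_moai (imported above when available), which returns the language
+    model and processor plus the segmentation, object-detection, scene-graph,
+    and OCR helper models the pipeline relies on.
+    """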
"num_beams": args.num_beams, + "temperature": args.temp, + "top_k": args.top_k, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "no_repeat_ngram_size": args.no_repeat_ngram_size, + "min_new_tokens": args.min_new_tokens, + "max_new_tokens": args.max_new_tokens, + "length_penalty": args.length_penalty, + } + print(gen_kwargs) + if args.max_new_tokens is not None: + logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length") + del gen_kwargs["max_length"] + + if not gen_kwargs["do_sample"]: + logging.info(f"** Using greedy sampling") + del gen_kwargs["top_k"] + del gen_kwargs["top_p"] + del gen_kwargs["temperature"] + else: + logging.info(f"** Sampling enabled") + return gen_kwargs + +def model_manager_factory(model_name: str): + if "moai" in model_name: + return MoaiManager(model_name) + else: + return CogVLMManager(model_name) + def main(args): prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args) + model_manager = model_manager_factory(args.model) - bnb_config = create_bnb_config(args) - - tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5') - model = AutoModelForCausalLM.from_pretrained( - 'THUDM/cogvlm-chat-hf', - torch_dtype=torch.bfloat16, - low_cpu_mem_usage=True, - trust_remote_code=True, # gee hope they don't get hacked or have a bad internal actor - #revision=... # no one is actually doing this - #load_in_4bit=not args.disable_4bit, - quantization_config=bnb_config, - ) - - do_sample = args.top_k is not None or args.top_p is not None or args.temp is not None - if do_sample: - args.top_k = args.top_k or 50 - args.top_p = args.top_p or 1.0 - args.temp = args.temp or 1.0 + model, tokenizer = model_manager.load_model() args.append = args.append or "" if len(args.append) > 0: args.append = " " + args.append.strip() - gen_kwargs = { - "max_length": args.max_length, - "do_sample": do_sample, - "length_penalty": args.length_penalty, - "num_beams": args.num_beams, - "temperature": args.temp, - "top_k": args.top_k, - "top_p": args.top_p, - "repetition_penalty": args.repetition_penalty, - "no_repeat_ngram_size": args.no_repeat_ngram_size, - "min_new_tokens": args.min_new_tokens, - "max_new_tokens": args.max_new_tokens, - "length_penalty": args.length_penalty, - } - - if args.max_new_tokens is not None: - logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length") - del gen_kwargs["max_length"] - - if not do_sample: - logging.info(f"** Using greedy sampling") - del gen_kwargs["top_k"] - del gen_kwargs["top_p"] - del gen_kwargs["temperature"] - else: - logging.info(f"** Sampling enabled") + gen_kwargs = model_manager.get_gen_kwargs(args) force_words_ids = None if args.force_words is not None: @@ -195,7 +259,7 @@ def main(args): starts_with = args.starts_with.strip() if args.starts_with is not None else "" - for i, image_path in enumerate(image_generator(args.image_dir, do_recurse=not args.no_recurse)): + for i, image_path in enumerate(image_path_generator(args.image_dir, do_recurse=not args.no_recurse)): candidate_caption_path = image_path.replace(os.path.splitext(image_path)[-1], ".txt") if args.no_overwrite and os.path.exists(candidate_caption_path): @@ -342,6 +406,7 @@ if __name__ == "__main__": argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.") argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.") argparser.add_argument("--append_log", 
action="store_true", help="Sets logging to append mode.") + argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.") args = argparser.parse_args() configure_logging(args) diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/data/gen_utils.py b/data/gen_utils.py deleted file mode 100644 index 404bcbf..0000000 --- a/data/gen_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import os -from typing import Generator - -SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"] - -def image_generator(image_dir: str, do_recurse: bool = True) -> Generator[str, None, None]: - if do_recurse: - for root, dirs, files in os.walk(image_dir): - for file in files: - if any(file.endswith(ext) for ext in SUPPORTED_EXT): - yield os.path.join(root, file) - else: - for file in os.listdir(image_dir): - if any(file.endswith(ext) for ext in SUPPORTED_EXT): - yield os.path.join(image_dir, file) \ No newline at end of file diff --git a/data/generators.py b/data/generators.py new file mode 100644 index 0000000..da5be51 --- /dev/null +++ b/data/generators.py @@ -0,0 +1,96 @@ +""" +Copyright [2022-2024] Victor C Hall + +Licensed under the GNU Affero General Public License; +You may not use this code except in compliance with the License. +You may obtain a copy of the License at + + https://www.gnu.org/licenses/agpl-3.0.en.html + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import os +from typing import Generator +from data.image_train_item import ImageTrainItem, ImageCaption +from PIL import Image, ImageOps +import tarfile +import logging + +SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"] + +class BucketBatchedGenerator(Generator[ImageTrainItem, None, None]): + """ + returns items in with the same aspect ratio in batches, for use with batching dataloaders + """ + def __init__(self, batch_size: int=1, generator: Generator[ImageTrainItem, None, None]=None): + self.caption = batch_size + self.cache = {} + self.generator = generator + + def __iter__(self): + for item in self.generator: + if item.target_wh: + aspect_bucket_key = item.target_wh + if aspect_bucket_key not in self.cache: + self.cache[aspect_bucket_key] = [] + self.cache[aspect_bucket_key].append(item) + if len(self.cache[aspect_bucket_key]) >= self.batch_size: + for item in self.cache[aspect_bucket_key]: + yield item + self.cache[aspect_bucket_key] = [] + +# def image_train_item_generator_from_tar_pairs(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]: +# for root, dirs, files in os.walk(image_dir): +# for file in files: +# if file.endswith(".tar"): +# tar_path = os.path.join(root, file) +# with tarfile.open(tar_path, "r") as tar: +# for tarinfo in tar: +# if tarinfo.isfile() and any(tarinfo.name.endswith(ext) for ext in SUPPORTED_EXT): +# try: +# img = Image.open(tar.extractfile(tarinfo)) +# txt = tar.extractfile(tarinfo.name.replace(os.path.splitext(tarinfo.name)[-1], ".txt")) +# caption = txt.read().decode("utf-8") +# img_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=256, use_weights=False) +# img = ImageOps.exif_transpose(img) +# iti = ImageTrainItem(img, img_caption) +# except 
+
+# def image_train_item_generator_from_tar_pairs(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
+#     for root, dirs, files in os.walk(image_dir):
+#         for file in files:
+#             if file.endswith(".tar"):
+#                 tar_path = os.path.join(root, file)
+#                 with tarfile.open(tar_path, "r") as tar:
+#                     for tarinfo in tar:
+#                         if tarinfo.isfile() and any(tarinfo.name.endswith(ext) for ext in SUPPORTED_EXT):
+#                             try:
+#                                 img = Image.open(tar.extractfile(tarinfo))
+#                                 txt = tar.extractfile(tarinfo.name.replace(os.path.splitext(tarinfo.name)[-1], ".txt"))
+#                                 caption = txt.read().decode("utf-8")
+#                                 img_caption = ImageCaption(main_prompt=caption, rating=0, tags=[], tag_weights=[], max_target_length=256, use_weights=False)
+#                                 img = ImageOps.exif_transpose(img)
+#                                 iti = ImageTrainItem(img, img_caption)
+#                             except Exception as e:
+#                                 logging.error(f"Failed to open {tarinfo.name}: {e}")
+#                                 continue
+#                             yield iti
+
+def image_train_item_generator_from_files(image_dir: str, do_recurse: bool = True) -> Generator[ImageTrainItem, None, None]:
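+    """
+    Yields an ImageTrainItem for each supported image under image_dir. The
+    caption is read from a sibling .txt file when present, otherwise the
+    file name, truncated at the first underscore, is used.
+    """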
""" @@ -68,7 +68,7 @@ class PromptIdentityPlugin(): prompt = f"Hint: {hint}\n{prompt}" return prompt -class HintFromFilename(PromptIdentityPlugin): +class HintFromFilename(PromptIdentityBase): def __init__(self, args:Namespace=None): super().__init__(key="hint_from_filename", description="Add a hint to the prompt using the filename of the image (without extension)", @@ -81,7 +81,7 @@ class HintFromFilename(PromptIdentityPlugin): prompt = self._add_hint_to_prompt(filename, prompt) return prompt -class RemoveUsingCSV(PromptIdentityPlugin): +class RemoveUsingCSV(PromptIdentityBase): def __init__(self, args:Namespace=None): super().__init__(key="remove_using_csv", description="Removes whole word matches of the csv passed in from the prompt", @@ -111,7 +111,7 @@ class RemoveUsingCSV(PromptIdentityPlugin): prompt = self._filter_logic(prompt, [word]) return prompt -class HintFromLeafDirectory(PromptIdentityPlugin): +class HintFromLeafDirectory(PromptIdentityBase): def __init__(self, args:Namespace=None): super().__init__(key="from_leaf_directory", description="Adds a hint to the prompt using the leaf directory name (last folder in path)", @@ -130,13 +130,13 @@ class MetadataProvider(): self._datadict = {} def _from_metadata(self, args) -> dict: - image_path = args.get("image_path", "") + image_path = args.image_path prompt = args.get("prompt", "") metadata = self._get_metadata_dict(image_path) return f"metadata: {metadata}\n{prompt}" def _get_metadata_dict(self, metadata_path: str) -> dict: - if not self.loaded and not metadata_path in self.cache: + if not metadata_path in self._datadict: metadata_dirname = os.path.dirname(metadata_path) if not os.path.exists(metadata_path): logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}") @@ -145,12 +145,12 @@ class MetadataProvider(): metadata = json.load(f) self._datadict[metadata_path] = metadata - return self.dict[metadata_path] + return self._datadict[metadata_path] -class FromFolderMetadataJson(PromptIdentityPlugin): +class FromFolderMetadataJson(PromptIdentityBase): def __init__(self, args:Namespace=None): super().__init__(key="from_folder_metadata", - description="Looks for metadata.json in the folder of the images", + description="Looks for metadata.json in the folder of the images and prefixes it to the prompt", fn=self._from_metadata_json, args=args) self.metadata_provider = MetadataProvider() @@ -159,11 +159,12 @@ class FromFolderMetadataJson(PromptIdentityPlugin): image_path = args.image_path image_dir = os.path.dirname(image_path) metadata_json_path = os.path.join(image_dir, "metadata.json") - self.metadata_provider._get_metadata_dict(metadata_json_path) + metadata = self.metadata_provider._get_metadata_dict(metadata_json_path) + metadata = json.dumps(metadata, indent=2) + prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt) + return prompt - return "" - -class TagsFromFolderMetadataJson(PromptIdentityPlugin): +class TagsFromFolderMetadataJson(PromptIdentityBase): def __init__(self, args:Namespace=None): self.cache = {} super().__init__(key = "tags_from_metadata_json", @@ -185,7 +186,7 @@ class TagsFromFolderMetadataJson(PromptIdentityPlugin): return self._add_hint_to_prompt(f"tags: {tags}", prompt) return prompt -class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin): +class TitleAndTagsFromImageJson(PromptIdentityBase): def __init__(self, args:Namespace=None): super().__init__(key="title_and_tags_from_image_json", description="Adds title and tags hint from metadata.json (in the 
+    if do_recurse:
+        for root, dirs, files in os.walk(image_dir):
+            for file in files:
+                if any(file.endswith(ext) for ext in SUPPORTED_EXT):
+                    yield os.path.join(root, file)
+    else:
+        for file in os.listdir(image_dir):
+            if any(file.endswith(ext) for ext in SUPPORTED_EXT):
+                yield os.path.join(image_dir, file)
\ No newline at end of file
diff --git a/plugins/caption_plugins.py b/plugins/caption_plugins.py
index b323d98..7592c54 100644
--- a/plugins/caption_plugins.py
+++ b/plugins/caption_plugins.py
@@ -22,7 +22,7 @@ class TestSub(TestBase):
     def __repr__(self) -> str:
         return f"TestSub: {self.a}, {self.b}"
 
-class PromptIdentityPlugin():
+class PromptIdentityBase():
     """
-    Base class for prompt alternation plugins, useful for captioning, etc.
+    Base class for prompt alteration plugins, useful for captioning, etc.
     """
@@ -68,7 +68,7 @@ class PromptIdentityPlugin():
             prompt = f"Hint: {hint}\n{prompt}"
         return prompt
 
-class HintFromFilename(PromptIdentityPlugin):
+class HintFromFilename(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="hint_from_filename",
                          description="Add a hint to the prompt using the filename of the image (without extension)",
@@ -81,7 +81,7 @@ class PromptIdentityPlugin():
         prompt = self._add_hint_to_prompt(filename, prompt)
         return prompt
 
-class RemoveUsingCSV(PromptIdentityPlugin):
+class RemoveUsingCSV(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="remove_using_csv",
                          description="Removes whole word matches of the csv passed in from the prompt",
@@ -111,7 +111,7 @@ class RemoveUsingCSV(PromptIdentityPlugin):
                 prompt = self._filter_logic(prompt, [word])
         return prompt
 
-class HintFromLeafDirectory(PromptIdentityPlugin):
+class HintFromLeafDirectory(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="from_leaf_directory",
                          description="Adds a hint to the prompt using the leaf directory name (last folder in path)",
@@ -130,13 +130,13 @@ class MetadataProvider():
         self._datadict = {}
 
     def _from_metadata(self, args) -> dict:
-        image_path = args.get("image_path", "")
-        prompt = args.get("prompt", "")
+        image_path = args.image_path
+        prompt = args.prompt
         metadata = self._get_metadata_dict(image_path)
         return f"metadata: {metadata}\n{prompt}"
 
     def _get_metadata_dict(self, metadata_path: str) -> dict:
-        if not self.loaded and not metadata_path in self.cache:
+        if metadata_path not in self._datadict:
             metadata_dirname = os.path.dirname(metadata_path)
             if not os.path.exists(metadata_path):
                 logging.warning(f" metadata.json not found in {metadata_dirname}, ignoring{Style.RESET_ALL}")
@@ -145,12 +145,12 @@ class MetadataProvider():
                 metadata = json.load(f)
                 self._datadict[metadata_path] = metadata
 
-        return self.dict[metadata_path]
+        return self._datadict[metadata_path]
 
-class FromFolderMetadataJson(PromptIdentityPlugin):
+class FromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="from_folder_metadata",
-                         description="Looks for metadata.json in the folder of the images",
+                         description="Looks for metadata.json in the folder of the images and prefixes it to the prompt",
                          fn=self._from_metadata_json,
                          args=args)
         self.metadata_provider = MetadataProvider()
@@ -159,11 +159,12 @@ class FromFolderMetadataJson(PromptIdentityPlugin):
         image_path = args.image_path
         image_dir = os.path.dirname(image_path)
         metadata_json_path = os.path.join(image_dir, "metadata.json")
-        self.metadata_provider._get_metadata_dict(metadata_json_path)
+        metadata = self.metadata_provider._get_metadata_dict(metadata_json_path)
+        metadata = json.dumps(metadata, indent=2)
+        prompt = self._add_hint_to_prompt(f"metadata: {metadata}", args.prompt)
+        return prompt
 
-        return ""
-
-class TagsFromFolderMetadataJson(PromptIdentityPlugin):
+class TagsFromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         self.cache = {}
         super().__init__(key = "tags_from_metadata_json",
@@ -185,7 +186,7 @@ class TagsFromFolderMetadataJson(PromptIdentityPlugin):
             return self._add_hint_to_prompt(f"tags: {tags}", prompt)
         return prompt
 
-class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin):
+class TitleAndTagsFromImageJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         super().__init__(key="title_and_tags_from_image_json",
                          description="Adds title and tags hint from metadata.json (in the same folder as the image) to the prompt",
@@ -218,7 +219,7 @@ class TitleAndTagsFromFolderImageJson(PromptIdentityPlugin):
         logging.debug(f" {self.key}: prompt after: {prompt}")
         return prompt
 
-class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
+class TitleAndTagsFromFolderMetadataJson(PromptIdentityBase):
     def __init__(self, args:Namespace=None):
         self.cache = {}
         super().__init__(key="title_and_tags_from_metadata_json",
@@ -254,7 +255,7 @@ class TitleAndTagsFromFolderMetadataJson(PromptIdentityPlugin):
         logging.debug(f" {self.key}: prompt after: {prompt}")
         return prompt
 
-class TitleAndTagsFromGlobalMetadataJson(PromptIdentityPlugin):
+class TitleAndTagsFromGlobalMetadataJson(PromptIdentityBase):
     """
     Adds title and tags hint from global metadata json given by '--metadatafilename'
     Note: you could just put your metadata in the prompt instead of using this plugin, but perhaps useful?
@@ -316,8 +317,8 @@ def get_prompt_alteration_plugin_list() -> list:
 
             if isinstance(attribute, type) \
                 and attribute.__module__ == module.__name__ \
-                and is_subclass_of_subclass(attribute, PromptIdentityPlugin, recursion_depth=5) \
-                and attribute is not PromptIdentityPlugin:
+                and is_subclass_of_subclass(attribute, PromptIdentityBase, recursion_depth=5) \
+                and attribute is not PromptIdentityBase:
                 plugins.append(attribute)
     #print(f"done checking plugins_module_name: {plugins_module_name}")
@@ -337,4 +338,4 @@ def load_prompt_alteration_plugin(plugin_key: str, args) -> callable:
         raise ValueError(f"plugin_key: {plugin_key} not found in prompt_alteration_plugins")
     else:
         logging.info(f"No plugin specified")
-    return PromptIdentityPlugin(args=args)
+    return PromptIdentityBase(args=args)