From fb1caf52a73b30e91bf1d6c7c36a600d5ec4cf76 Mon Sep 17 00:00:00 2001 From: Victor Hall Date: Sat, 4 May 2024 22:24:17 -0400 Subject: [PATCH] add llava captioning --- caption_cog.py | 377 +++++++++++++++++++++++-------------- doc/ADVANCED_TWEAKING.md | 8 +- plugins/caption_plugins.py | 48 +++++ utils/ed_logging.py | 17 ++ 4 files changed, 304 insertions(+), 146 deletions(-) create mode 100644 utils/ed_logging.py diff --git a/caption_cog.py b/caption_cog.py index 8277234..9a24548 100644 --- a/caption_cog.py +++ b/caption_cog.py @@ -30,12 +30,14 @@ from PIL import Image import PIL.ImageOps as ImageOps from pynvml import * -from transformers import AutoModelForCausalLM, LlamaTokenizer, PreTrainedTokenizer, BitsAndBytesConfig +from transformers import AutoModelForCausalLM, LlamaTokenizer, BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor, LlavaProcessor, AutoTokenizer + from transformers.modeling_outputs import BaseModelOutputWithPast from colorama import Fore, Style from plugins.caption_plugins import load_prompt_alteration_plugin from utils.patch_cog import patch_cog +from utils.ed_logging import configure_logging from data.generators import image_path_generator, SUPPORTED_EXT try: @@ -48,54 +50,8 @@ Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit IMAGE_SIZE: int = 490 PATCH_SIZE: int = 14 -patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions - -def build_conversation_input_ids( - tokenizer: PreTrainedTokenizer, - *, - query: str, - history: Optional[List[Tuple[str, str]]] = None, - images: Optional[List[Image.Image]] = None, - starts_with: Optional[str] = None, - ): - # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py - image_size: int = IMAGE_SIZE - patch_size: int = PATCH_SIZE - assert images is None or len(images) <= 1, f"not support multi images by now." - history = history or [] - - text = f"Question: {query} Answer: " - text += starts_with if starts_with is not None else "" - - input_ids = [tokenizer.bos_token_id] - token_type_ids = [0] - if images is not None and len(images) == 1: - # vision - transform = transforms.Compose( - [ - transforms.Resize( - (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC - ), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ] - ) - images = [transform(images[0])] - vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2 - input_ids += [tokenizer.pad_token_id] * vision_token_num - token_type_ids += [1] * vision_token_num - text_ids = tokenizer.encode(text, add_special_tokens=False) - - input_ids += text_ids - token_type_ids += [0] * len(text_ids) - attention_mask = [1] * len(input_ids) - - return { - "input_ids": torch.tensor(input_ids, dtype=torch.long), - "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long), - "attention_mask": torch.tensor(attention_mask, dtype=torch.long), - "images": images, - } +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.benchmark = True def get_gpu_memory_map(): nvmlInit() @@ -116,7 +72,7 @@ def save_params(args, gen_kwargs): def create_bnb_config(): return BitsAndBytesConfig( - bnb_4bit_compute_dtype="float32", + bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "fp4", bnb_4bit_use_double_quant=False, llm_int8_enable_fp32_cpu_offload=False, @@ -128,6 +84,121 @@ def create_bnb_config(): quant_method="bitsandbytes" ) +class BaseModelWrapper: + def __init__(self, model_name): + self.model_name = model_name + logging.info(f"Loading {model_name}") + + def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"): + self.model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + ).to(0) + + self.tokenizer = AutoProcessor.from_pretrained(self.model_name) + return self.model, self.tokenizer + + def get_gen_kwargs(self, args): + gen_kwargs = { + "max_length": args.max_length, + "do_sample": args.top_k is not None or args.top_p is not None or args.temp is not None or False, + "length_penalty": args.length_penalty, + "num_beams": args.num_beams, + "temperature": args.temp, + "top_k": args.top_k, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "no_repeat_ngram_size": args.no_repeat_ngram_size, + "min_new_tokens": args.min_new_tokens, + "max_new_tokens": args.max_new_tokens, + "length_penalty": args.length_penalty, + } + + logging.info(gen_kwargs) + + if args.max_new_tokens is not None: + logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length") + del gen_kwargs["max_length"] + + if not gen_kwargs["do_sample"]: + logging.info(f"** Using greedy sampling") + del gen_kwargs["top_k"] + del gen_kwargs["top_p"] + del gen_kwargs["temperature"] + else: + logging.info(f"** Sampling enabled") + return gen_kwargs + + def caption(prompt, args): + return "" + +class XtunerLlavaModelManager(BaseModelWrapper): # https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers + def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"): + self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers" + super().__init__(model_name) + + + def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"): + self.model = LlavaForConditionalGeneration.from_pretrained( + #self.model = AutoModelForCausalLM.from_pretrained( + self.model_name, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + #quantization_config=create_bnb_config() + ).to(0) + + self.processor = LlavaProcessor.from_pretrained(self.model_name) + self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers") + print(f"self.tokenizer: {self.tokenizer}") + # tokens = self.tokenizer("foo") + # print(f"foo tokens test1: {tokens}") + return self.model, self.tokenizer + + def get_inputs(self, image: Image.Image, prompt: str): + inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16) + return inputs + + def _build_conversational_input_ids(self, prompt, starts_with): + return (f"<|start_header_id|>user<|end_header_id|>\n\n\n{prompt}<|eot_id|>" + f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}") + + def _get_full_sentences(self, caption, args): + logging.debug(f"**DEBUG: XtunerLlava presplit caption: {caption}") + if args.max_length is not None and len(caption) > args.max_length: + caption = caption[:args.max_length] + + caption = caption.split(".") + #sentence_count = min(4, len(caption)) + caption = ". ".join(caption[0:-1]) + "." + + logging.debug(f"**DEBUG: caption: {caption}") + return caption + + def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]): + gen_kwargs = self.get_gen_kwargs(args) + + prompt = self._build_conversational_input_ids(prompt, args.starts_with) + inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16) + # inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16) + + inputs = { + "input_ids": inputs["input_ids"], + "attention_mask": inputs['attention_mask'], + "pixel_values": inputs['pixel_values'], + #"images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)], + #"output_hidden_states": True, + #"return_dict": True + } + len_inputs = inputs['input_ids'].shape[1] + + outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) + + caption = self.processor.decode(outputs[0][len_inputs:], skip_special_tokens=True) + + caption = self._get_full_sentences(caption, args) + return caption + class MoaiManager: def __init__(self, model_name: str): self.model_name = model_name @@ -155,8 +226,8 @@ class MoaiManager: return moai_model, moai_processor def get_inputs(self, image: Image.Image, prompt: str): - moai_inputs = self.moai_model.demo_process(image=image, - prompt=prompt, + moai_inputs = self.moai_model.demo_process(image=image, + prompt=prompt, processor=self.moai_processor, seg_model=self.seg_model, seg_processor=self.seg_processor, @@ -167,17 +238,17 @@ class MoaiManager: device="cuda:0") return moai_inputs - def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any: - with torch.inference_mode(): - generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache) - answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0] - return answer + # def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any: + # with torch.inference_mode(): + # generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache) + # answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0] + # return answer -class CogVLMManager: +class CogVLMManager(BaseModelWrapper): def __init__(self, model_name: str): - self.model_name = model_name - self.tokenizer = None - self.model = None + super().__init__(model_name) + self.model_name = "THUDM/cogvlm-chat-hf" + patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions def load_model(self): self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5") @@ -190,67 +261,120 @@ class CogVLMManager: ) return self.model, self.tokenizer - def get_inputs(self, prompt: str, history: List[Tuple[str, str]], images: List[Image.Image], starts_with: str): - return build_conversation_input_ids(self.tokenizer, query=prompt, history=history, images=images, starts_with=starts_with) + def _build_conversation_input_ids(self, + *, + query: str, + history: Optional[List[Tuple[str, str]]] = None, + images: Optional[List[Image.Image]] = None, + starts_with: Optional[str] = None, + ): + # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py + image_size: int = IMAGE_SIZE + patch_size: int = PATCH_SIZE + assert images is None or len(images) <= 1, f"not support multi images by now." + history = history or [] - def get_gen_kwargs(self, args): - gen_kwargs = { - "max_length": args.max_length, - "do_sample": args.top_k is not None or args.top_p is not None or args.temp is not None or False, - "length_penalty": args.length_penalty, - "num_beams": args.num_beams, - "temperature": args.temp, - "top_k": args.top_k, - "top_p": args.top_p, - "repetition_penalty": args.repetition_penalty, - "no_repeat_ngram_size": args.no_repeat_ngram_size, - "min_new_tokens": args.min_new_tokens, - "max_new_tokens": args.max_new_tokens, - "length_penalty": args.length_penalty, + text = f"Question: {query} Answer: " + text += starts_with if starts_with is not None else "" + + input_ids = [self.tokenizer.bos_token_id] + token_type_ids = [0] + if images is not None and len(images) == 1: + # vision + transform = transforms.Compose( + [ + transforms.Resize( + (image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) + images = [transform(images[0])] + vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2 + input_ids += [self.tokenizer.pad_token_id] * vision_token_num + token_type_ids += [1] * vision_token_num + text_ids = self.tokenizer.encode(text, add_special_tokens=False) + + input_ids += text_ids + token_type_ids += [0] * len(text_ids) + attention_mask = [1] * len(input_ids) + + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long), + "attention_mask": torch.tensor(attention_mask, dtype=torch.long), + "images": images, } - print(gen_kwargs) - if args.max_new_tokens is not None: - logging.info(f"** max_new_tokens set to {args.max_new_tokens}, ignoring max_length") - del gen_kwargs["max_length"] - if not gen_kwargs["do_sample"]: - logging.info(f"** Using greedy sampling") - del gen_kwargs["top_k"] - del gen_kwargs["top_p"] - del gen_kwargs["temperature"] - else: - logging.info(f"** Sampling enabled") - return gen_kwargs + def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]): + gen_kwargs = self.get_gen_kwargs(args) -def model_manager_factory(model_name: str): + inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with) + + inputs = { + "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"), + "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"), + "attention_mask": inputs['attention_mask'].unsqueeze(0).to("cuda"), + "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)], + "output_hidden_states": True, + "return_dict": True + } + outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) + #print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}") + #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}") + + len_inputs = inputs['input_ids'].shape[1] + outputs_without_prompt = outputs[:, len_inputs:] + + caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True) + return caption + +def get_model_wrapper(model_name: str): if "moai" in model_name: return MoaiManager(model_name) + elif "llava" in model_name: + return XtunerLlavaModelManager(model_name) else: return CogVLMManager(model_name) +def get_inputs_dict(inputs): + inputs = { + "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"), + "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"), + "attention_mask": inputs['attention_mask'].unsqueeze(0).to("cuda"), + "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)], + "output_hidden_states": True, + "return_dict": True + } + def main(args): prompt_plugin_fn = load_prompt_alteration_plugin(args.prompt_plugin, args=args) - model_manager = model_manager_factory(args.model) - - model, tokenizer = model_manager.load_model() + model_wrapper = get_model_wrapper(args.model) + model_wrapper.load_model() args.append = args.append or "" if len(args.append) > 0: args.append = " " + args.append.strip() - gen_kwargs = model_manager.get_gen_kwargs(args) + gen_kwargs = model_wrapper.get_gen_kwargs(args) force_words_ids = None if args.force_words is not None: force_words = args.force_words.split(",") if args.force_words is not None else [] logging.info(f"** force_words: {Fore.LIGHTGREEN_EX}{force_words}{Style.RESET_ALL}") - force_words_ids = tokenizer(force_words, add_special_tokens=False)["input_ids"] if force_words else [] + # if args.model contains "cog" + if "cog" in args.model: + force_words_ids = model_wrapper.tokenizer(force_words, add_special_tokens=False)["input_ids"] if force_words else [] + else: + force_words_ids = model_wrapper.tokenizer(force_words)["input_ids"] if force_words else [] bad_words_ids = None if args.bad_words is not None: bad_words = args.bad_words.split(",") if args.bad_words is not None else [] logging.info(f"** bad_words: {Fore.LIGHTGREEN_EX}{bad_words}{Style.RESET_ALL}") - bad_words_ids = tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else [] + bad_words_ids = model_wrapper.tokenizer(bad_words, add_special_tokens=False)["input_ids"] if bad_words else [] + #print(bad_words_ids) logging.info(f"** gen_kwargs: \n{Fore.LIGHTGREEN_EX}{gen_kwargs}{Style.RESET_ALL}") @@ -277,41 +401,20 @@ def main(args): except Exception as e: logging.warning(f"Non-fatal error processing {image_path}: {e}") continue - + + pixel_count = image.height * image.width + if pixel_count < args.min_pixels: + logging.warning(f" * Image under {args.min_pixels} pixels, skipping. Path: {image_path}") + continue + logging.debug(f" __ Prompt before plugin: {Fore.LIGHTGREEN_EX}{args.prompt}{Style.RESET_ALL}") prompt = prompt_plugin_fn(image_path, args=args) logging.debug(f" __ Modified prompt after plugin: {Fore.LIGHTGREEN_EX}{prompt}{Style.RESET_ALL}") - inputs = build_conversation_input_ids(tokenizer, query=prompt, history=[], images=[image], starts_with=args.starts_with) # chat mode - - inputs = { - "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"), - "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"), - "attention_mask": inputs['attention_mask'].unsqueeze(0).to("cuda"), - "images": [[inputs["images"][0].to("cuda").to(torch.bfloat16)] for _ in range(args.num_beams)], - "output_hidden_states": True, - "return_dict": True - } - - # print(f"** inputs type: {type(inputs)}") # dict - # print(f"** inputs len: {len(inputs)}") # 4 - # print(f"** inputs keys: {inputs.keys()}") # dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'images']) - # print(f"** inputs['images'] shape: {inputs['images'].shape}") # list has no shape - # print(f"** image_path: {image_path}") - with torch.no_grad(): - #input_decoded = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True) - #logging.debug(f"inputs decoded: {input_decoded}") - #print(f"calling generate with input shapes: {inputs['input_ids'].shape}, {inputs['token_type_ids'].shape}, {inputs['attention_mask'].shape}, {inputs['images'][0][0].shape}") - #calling generate with input shapes: torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([1, 1352]), torch.Size([3, 490, 490]) - outputs = model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) - print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}") - #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}") + #def caption(self, prompt, images, args, force_words_ids, bad_words_ids, history=[]): + caption = model_wrapper.caption(prompt, image, args, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) - len_inputs = inputs['input_ids'].shape[1] - outputs_without_prompt = outputs[:, len_inputs:] - - caption = tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True) if not args.remove_starts_with: # deal with caption starting with comma, etc if not re.match(r"^\W", caption): @@ -325,7 +428,7 @@ def main(args): f.write(caption) vram_gb = get_gpu_memory_map() elapsed_time = time.time() - cap_start_time - logging.info(f"n:{i:05}, VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ") + logging.info(f"n:{i:05}, VRAM: {Fore.LIGHTYELLOW_EX}{vram_gb:0.1f} GB{Style.RESET_ALL}, elapsed: {Fore.LIGHTYELLOW_EX}{elapsed_time:0.1f}{Style.RESET_ALL} sec, sqrt_pixels: {pow(float(pixel_count),0.5):0.1f}, Captioned {Fore.LIGHTYELLOW_EX}{image_path}{Style.RESET_ALL}: ") logging.info(f"{Fore.LIGHTCYAN_EX}{caption}{Style.RESET_ALL}") i_processed += 1 @@ -339,26 +442,13 @@ def main(args): logging.info(f"** Done captioning {args.image_dir} with prompt '{prompt}', total elapsed: {hh_mm_ss} (hh_mm_ss), avg: {avg_time:0.1f} sec/image") -def configure_logging(args: argparse.Namespace): - level = logging.INFO if not args.debug else logging.DEBUG - filemode = "a" if args.append_log else "w" - logging.basicConfig(filename="caption_cog.log", - level=level, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - filemode=filemode) - - console = logging.StreamHandler() - console.setLevel(level) - console.setFormatter(logging.Formatter('%(message)s')) - logging.getLogger('').addHandler(console) - EXAMPLES = """ex. Basic example: python caption_cog.py --image_dir /mnt/mydata/kyrie/ --prompt 'Describe this image in detail, including the subject matter and medium of the artwork.' Use probabilistic sampling by using any of top_k, top_p, or temp: - python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"What is this?\" --top_p 0.9 - + python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"What is this?\" --top_p 0.9 + Use beam search and probabilistic sampling: python caption_cog.py --image_dir \"c:/users/chadley/my documents/pictures\" --prompt \"Write a description.\" --max_new_tokens 75 --num_beams 4 --temp 0.9 --top_k 3 --top_p 0.9 --repetition_penalty 1.0 --no_repeat_ngram_size 0 --min_new_tokens 5 @@ -409,9 +499,10 @@ if __name__ == "__main__": argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.") argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.") argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.") + argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped") args, unknown_args = argparser.parse_known_args() - - configure_logging(args) + + configure_logging(args, "caption_cog.log") unknown_args_dict = {} for i in range(0, len(unknown_args), 2): diff --git a/doc/ADVANCED_TWEAKING.md b/doc/ADVANCED_TWEAKING.md index 1dd371b..3dbd194 100644 --- a/doc/ADVANCED_TWEAKING.md +++ b/doc/ADVANCED_TWEAKING.md @@ -46,9 +46,9 @@ This may also be useful to really "force" a style into the model with a high set ## Timestep clamping -Stable Diffusion uses 1000 possible timesteps for denoising steps. If you wish to train only a portion of those timesteps instead of the entire schedule you can clamp the value. +Stable Diffusion uses 1000 possible timesteps for denoising steps. Timesteps are always chosen randomly per training example, per step, within the possible or allowed timesteps. -Timesteps are always chosen randomly per training example, per step, within the possible or allowed timesteps. +If you wish to train only a portion of those timesteps instead of the entire schedule you can clamp the value. For instance, if you only want to train from 500 to 999, use this: @@ -58,7 +58,9 @@ Or if you only want to try from 0 to 449, use this: --timestep_end 450 -Possible use cases are to "focus" training on aesthetics or composition. It's likely you may need to train all timesteps as a "clean up" if you train just specific timestep ranges first. +Possible use cases are to "focus" training on aesthetics or composition by limiting timesteps and training specific data with certain qualities. It's likely you may need to train all timesteps as a "clean up" if you train just specific timestep ranges first so the model does not overfit the fine tuned timesteps and lead to problems during inference. + +This could also be used to train expert models for specific timestep ranges, similar to the SDXL Refiner model. ## Loss Type diff --git a/plugins/caption_plugins.py b/plugins/caption_plugins.py index 0af848a..b263346 100644 --- a/plugins/caption_plugins.py +++ b/plugins/caption_plugins.py @@ -228,6 +228,54 @@ class TitleAndTagsFromImageJson(PromptIdentityBase): logging.debug(f" {self.key}: prompt after: {prompt}") return prompt +class VogueRunwayImageJson(PromptIdentityBase): + def __init__(self, args:Namespace=None): + super().__init__(key="vogue_runway_from_image_json", + description="Adds title and tags hint from metadata.json (in the samefolder as the image) to the prompt", + fn=self._title_and_tags_from_metadata_json, + args=args) + + def try_get_kvps(self, metadata, keys:list): + values = [] + for key in keys: + val = metadata.get(key, "") + if not val: + continue + if type(val) == int: + val = str(val) + val = val.strip() + values.append(f"{key}: {val}") + hint = ", ".join(values) + return hint + + def _title_and_tags_from_metadata_json(self, args:Namespace) -> str: + prompt = args.prompt + logging.debug(f" {self.key}: prompt before: {prompt}") + image_path = args.image_path + current_dir = os.path.dirname(image_path) + image_path_base = os.path.basename(image_path) + image_path_without_extension = os.path.splitext(image_path_base)[0] + candidate_json_path = os.path.join(current_dir, f"{image_path_without_extension}.json") + + if os.path.exists(candidate_json_path): + with open(candidate_json_path, "r") as f: + metadata = json.load(f) + + keys = ["designer","season","category","year"] + + hint = "" + hint = self.try_get_kvps(metadata, keys) + + tags = metadata.get("tags", []) + tags = tags.split(",") if isinstance(tags, str) else tags # can be csv or list + if tags and len(tags) > 0: + tags = ", ".join(tags) + hint += f"\nTags: {tags}" + + prompt = self._add_hint_to_prompt(hint, prompt) + logging.debug(f" {self.key}: prompt after: {prompt}") + return prompt + class TitleAndTagsFromFolderMetadataJson(PromptIdentityBase): def __init__(self, args:Namespace=None): self.cache = {} diff --git a/utils/ed_logging.py b/utils/ed_logging.py new file mode 100644 index 0000000..d604d95 --- /dev/null +++ b/utils/ed_logging.py @@ -0,0 +1,17 @@ +import logging +import argparse + +def configure_logging(args: argparse.Namespace, log_file=None): + level = logging.INFO if not args.debug else logging.DEBUG + + if log_file: + filemode = "a" if args.append_log else "w" + logging.basicConfig(filename=log_file, + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + filemode=filemode) + + console = logging.StreamHandler() + console.setLevel(level) + console.setFormatter(logging.Formatter('%(message)s')) + logging.getLogger('').addHandler(console) \ No newline at end of file