diff --git a/caption_cog.py b/caption_cog.py index 1cd6d82..321ca15 100644 --- a/caption_cog.py +++ b/caption_cog.py @@ -47,7 +47,8 @@ except ImportError: Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit -IMAGE_SIZE: int = 490 +IMAGE_SIZE_COG1: int = 490 +IMAGE_SIZE_COG2: int = 1344 PATCH_SIZE: int = 14 torch.backends.cuda.matmul.allow_tf32 = True @@ -70,7 +71,7 @@ def save_params(args, gen_kwargs): with open(save_path, "w") as f: f.write(pretty_print) -def create_bnb_config(bnb_4bit_compute_dtype="bfloat16",bnb_4bit_quant_type= "fp4"): +def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "fp4"): return BitsAndBytesConfig( bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, bnb_4bit_quant_type=bnb_4bit_quant_type, @@ -89,15 +90,26 @@ class BaseModelWrapper: self.model_name = model_name logging.info(f"Loading {model_name}") - def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"): + def load_model(self, dtype: str="auto"): + bnb_config = self._maybe_create_bnb_config(dtype) self.model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.float16, low_cpu_mem_usage=True, + quantization_config = bnb_config ).to(0) self.tokenizer = AutoProcessor.from_pretrained(self.model_name) return self.model, self.tokenizer + + def _maybe_create_bnb_config(self, dtype, auto_bnb=True, auto_bnb_dtype="fp4"): + bnb_config = None + if dtype == "auto": + if auto_bnb: + bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=auto_bnb_dtype) + if dtype in ["nf4", "fp4"]: + bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=dtype) + return bnb_config def get_gen_kwargs(self, args): gen_kwargs = { @@ -129,47 +141,7 @@ class BaseModelWrapper: else: logging.debug(f"** Sampling enabled") return gen_kwargs - - def caption(prompt, args): - return "" - -class XtunerLlavaModelManager(BaseModelWrapper): - def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"): - self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers" - super().__init__(model_name) - - def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"): - self.model = LlavaForConditionalGeneration.from_pretrained( - #self.model = AutoModelForCausalLM.from_pretrained( - self.model_name, - torch_dtype=torch.float16, - low_cpu_mem_usage=True, - #quantization_config=create_bnb_config() - ).to(0) - - self.processor = LlavaProcessor.from_pretrained(self.model_name) - self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers") - print(f"self.tokenizer: {self.tokenizer}") - # tokens = self.tokenizer("foo") - # print(f"foo tokens test1: {tokens}") - return self.model, self.tokenizer - - def get_inputs(self, image: Image.Image, prompt: str): - inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16) - return inputs - - def _build_conversational_input_ids(self, prompt, starts_with): - return (f"<|start_header_id|>user<|end_header_id|>\n\n\n{prompt}<|eot_id|>" - f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}") - - def _truncate_to_whole_sentences(self, caption): - # model does not stop generating cleanly and cuts off mid sentence - caption = caption.split(".") - caption = ". ".join(caption[0:-1]) + "." - caption = caption.replace("\n","") - caption = caption.replace(" "," ") - return caption - + def _clean_caption(self, caption, args): """ Removes some nonsense Llava adds. @@ -194,11 +166,52 @@ class XtunerLlavaModelManager(BaseModelWrapper): caption = caption.replace(", who is the main subject of the photo.", ".") caption = caption.replace(", who is the main subject.", ".") caption = caption.replace("who is the main subject.", ".") + caption = caption.replace(", who is the central focus of the composition.", ".") + caption = caption.replace("who is the central focus of the composition.", ".") caption = self._truncate_to_whole_sentences(caption) logging.debug(f"**Llava post-cleaning caption: {caption}") return caption + def caption(prompt, args): + return "" + +class XtunerLlavaModelManager(BaseModelWrapper): + def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"): + self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers" + super().__init__(model_name) + logging.info("Loading Xtuner Llava-Llama3 model...") + + def load_model(self, dtype="auto"): + bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb=False) + self.model = LlavaForConditionalGeneration.from_pretrained( + self.model_name, + torch_dtype=torch.float16, + low_cpu_mem_usage=True, + quantization_config=bnb_config + ).to("cuda") + + self.processor = LlavaProcessor.from_pretrained(self.model_name) + self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers") + + return self.model, self.tokenizer + + def get_inputs(self, image: Image.Image, prompt: str): + inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16) + return inputs + + def _build_conversational_input_ids(self, prompt, starts_with): + return (f"<|start_header_id|>user<|end_header_id|>\n\n\n{prompt}<|eot_id|>" + f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}") + + def _truncate_to_whole_sentences(self, caption): + # model does not stop generating cleanly and cuts off mid sentence + caption = caption.split(".") + caption = ". ".join(caption[0:-1]) + "." + caption = caption.replace("\n","") + caption = caption.replace(" "," ") + return caption + def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]): gen_kwargs = self.get_gen_kwargs(args) @@ -227,59 +240,110 @@ class XtunerLlavaModelManager(BaseModelWrapper): caption = self._clean_caption(caption, args) return caption -class MoaiManager: +# class MoaiManager: +# def __init__(self, model_name: str): +# self.model_name = model_name +# self.moai_model = None +# self.moai_processor = None +# self.seg_model = None +# self.seg_processor = None +# self.od_model = None +# self.od_processor = None +# self.sgg_model = None +# self.ocr_model = None +# logging.info("Loading Moai model...") + +# def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"): +# moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \ +# = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype) +# self.moai_model = moai_model +# self.moai_processor = moai_processor +# self.seg_model = seg_model +# self.seg_processor = seg_processor +# self.od_model = od_model +# self.od_processor = od_processor +# self.sgg_model = sgg_model +# self.ocr_model = ocr_model + +# return moai_model, moai_processor + +# def get_inputs(self, image: Image.Image, prompt: str): +# moai_inputs = self.moai_model.demo_process(image=image, +# prompt=prompt, +# processor=self.moai_processor, +# seg_model=self.seg_model, +# seg_processor=self.seg_processor, +# od_model=self.od_model, +# od_processor=self.od_processor, +# sgg_model=self.sgg_model, +# ocr_model=self.ocr_model, +# device="cuda:0") +# return moai_inputs + +class CogGLMManager(BaseModelWrapper): def __init__(self, model_name: str): - self.model_name = model_name - self.moai_model = None - self.moai_processor = None - self.seg_model = None - self.seg_processor = None - self.od_model = None - self.od_processor = None - self.sgg_model = None - self.ocr_model = None + super().__init__(model_name) + if not model_name: + self.model_name = "THUDM/cogglm-6b" + else: + self.model_name = model_name + logging.info("Loading CogGLM model...") - def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"): - moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \ - = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype) - self.moai_model = moai_model - self.moai_processor = moai_processor - self.seg_model = seg_model - self.seg_processor = seg_processor - self.od_model = od_model - self.od_processor = od_processor - self.sgg_model = sgg_model - self.ocr_model = ocr_model + def load_model(self, dtype: str = "auto"): + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + bnb_config = None + if dtype in ["auto","nf4"]: + bnb_config = create_bnb_config() + self.model = model = AutoModelForCausalLM.from_pretrained( + "THUDM/glm-4v-9b", + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + trust_remote_code=True, + quantization_config=bnb_config + ).eval() + if bnb_config is None: + # if BNB is used it is automatically sent to cuda device, otherwise need to move it manually + self.model = model.to("cuda") - return moai_model, moai_processor + def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]): + gen_kwargs = self.get_gen_kwargs(args) - def get_inputs(self, image: Image.Image, prompt: str): - moai_inputs = self.moai_model.demo_process(image=image, - prompt=prompt, - processor=self.moai_processor, - seg_model=self.seg_model, - seg_processor=self.seg_processor, - od_model=self.od_model, - od_processor=self.od_processor, - sgg_model=self.sgg_model, - ocr_model=self.ocr_model, - device="cuda:0") - return moai_inputs + inputs = self.tokenizer.apply_chat_template([{"role": "user", "image": image, "content": prompt}], + add_generation_prompt=True, tokenize=True, return_tensors="pt", + return_dict=True) + inputs.to("cuda") - # def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any: - # with torch.inference_mode(): - # generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache) - # answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0] - # return answer + outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) + + len_inputs = inputs['input_ids'].shape[1] + outputs_without_prompt = outputs[:, len_inputs:] + + caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True) + return caption class CogVLMManager(BaseModelWrapper): def __init__(self, model_name: str): super().__init__(model_name) - self.model_name = "THUDM/cogvlm-chat-hf" + if not model_name: + self.model_name = "THUDM/cogvlm-chat-hf" + self.cog_version = 1 + elif model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower(): + self.model_name = "THUDM/cogvlm2-llama3-chat-19b" + self.cog_version = 2 + else: + self.model_name = model_name + self.cog_version = 1 patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions + logging.info("Loading CogVLM model...") - def load_model(self): - self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5") + def load_model(self, dtype: str = "auto"): + if self.model_name.lower() == "THUDM/cogvlm-chat-hf".lower(): + self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5") + elif self.model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower(): + self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True) + self.tokenizer.pad_token_id = 128002 # for Llama 3 + else: + raise ValueError("Unknown model name") self.model = AutoModelForCausalLM.from_pretrained( self.model_name, torch_dtype=torch.bfloat16, @@ -297,7 +361,7 @@ class CogVLMManager(BaseModelWrapper): starts_with: Optional[str] = None, ): # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py - image_size: int = IMAGE_SIZE + image_size: int = IMAGE_SIZE_COG2 if self.cog_version == 2 else IMAGE_SIZE_COG1 patch_size: int = PATCH_SIZE assert images is None or len(images) <= 1, f"not support multi images by now." history = history or [] @@ -306,9 +370,8 @@ class CogVLMManager(BaseModelWrapper): text += starts_with if starts_with is not None else "" input_ids = [self.tokenizer.bos_token_id] - token_type_ids = [0] + token_type_ids = [0] # LANGUAGE_TOKEN_TYPE if images is not None and len(images) == 1: - # vision transform = transforms.Compose( [ transforms.Resize( @@ -319,7 +382,11 @@ class CogVLMManager(BaseModelWrapper): ] ) images = [transform(images[0])] - vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2 + if self.cog_version == 1: + vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2 + elif self.cog_version == 2: + vision_token_num = (image_size // patch_size // 2) * (image_size // patch_size // 2) + 2 + input_ids += [self.tokenizer.pad_token_id] * vision_token_num token_type_ids += [1] * vision_token_num text_ids = self.tokenizer.encode(text, add_special_tokens=False) @@ -340,10 +407,6 @@ class CogVLMManager(BaseModelWrapper): inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with) - # inputs['input_ids'].shape: torch.Size([1259]) - # inputs['attention_mask'].shape: torch.Size([1259]) - # inputs['images'][0].shape: torch.Size([3, 490, 490]) - inputs = { "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"), "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"), @@ -352,15 +415,8 @@ class CogVLMManager(BaseModelWrapper): "output_hidden_states": True, "return_dict": True } - # inputs['input_ids'].shape: torch.Size([1, 1259]) - # inputs['attention_mask'].shape: torch.Size([1, 1259]) - # inputs['images'][0][0].shape: torch.Size([3, 490, 490]) - # len(inputs['images'][0]): 1 - # len(inputs['images'][0][0]): 3 outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids) - #print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}") - #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}") len_inputs = inputs['input_ids'].shape[1] outputs_without_prompt = outputs[:, len_inputs:] @@ -369,12 +425,22 @@ class CogVLMManager(BaseModelWrapper): return caption def get_model_wrapper(model_name: str): - if "moai" in model_name: - return MoaiManager(model_name) - elif "llava" in model_name: - return XtunerLlavaModelManager(model_name) - else: - return CogVLMManager(model_name) + match model_name.casefold(): + # case x if "moai" in x: + # #return MoaiManager(model_name) + # return None + case x if "llava" in x: + return XtunerLlavaModelManager(model_name) + case "thudm/glm-4v-9b": + return CogGLMManager(model_name) + case "thudm/cogvlm2-llama3-chat-19b": + return CogVLMManager(model_name) + case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]: + return CogVLMManager(model_name) + case None: + return CogVLMManager(model_name) + case _: + raise ValueError(f"Model {model_name} not supported") def get_inputs_dict(inputs): inputs = { @@ -518,6 +584,7 @@ if __name__ == "__main__": argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)") argparser.add_argument("--debug", action="store_true", help="Enable debug logging") argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.") + argparser.add_argument("--dtype", choices=["auto","fp16","bf16","nf4","fp4"], default="auto", help="Data type for inference (def: auto, see docs)") argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling") argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)") argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling") @@ -540,7 +607,7 @@ if __name__ == "__main__": argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.") argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.") argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.") - argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.") + argparser.add_argument("--model", type=str, default=None, help="Model to use for captioning.") argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped") args, unknown_args = argparser.parse_known_args() diff --git a/doc/CAPTION_COG.md b/doc/CAPTION_COG.md index af8e3f5..f44d233 100644 --- a/doc/CAPTION_COG.md +++ b/doc/CAPTION_COG.md @@ -8,13 +8,23 @@ It is capable of naming and identifying things with proper nouns and has a large Open In Colab +Both the ([Vicuna-based](https://huggingface.co/THUDM/cogvlm-chat-hf)) and ([Llama3-based](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B)) models are supported. + +Choose these by using one of these two CLI args: + + --model THUDM/cogvlm-chat-hf + + --model THUDM/cogvlm2-llama3-chat-19B + +The script uses the Vicuna model (first) by default if no `--model` arg is specified. + ## Llava update This script now (confusiningly) supports (Xtuner's Llava Llama3 8b v1.1)[https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main]. To use, add `--model "https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main"` to your command line. -This is a work in progress. So far it seems bad_words do not work. +When using Llava, the script will perform some clean-up operations to remove some less-than-useful language from the caption because the bad_words part of the Hugginface Transformers API is not supported by Llava. ## Basics diff --git a/docker/Dockerfile b/docker/Dockerfile index 7667ba0..f965bbd 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,7 +5,7 @@ LABEL org.opencontainers.image.licenses="AGPL-3.0-only" ARG DEBIAN_FRONTEND=noninteractive # Don't write .pyc bytecode -ENV PYTHONDONTWRITEBYTECODE=1 +# ENV PYTHONDONTWRITEBYTECODE=1 # Create workspace working directory RUN mkdir /build @@ -49,7 +49,7 @@ ENV DEBIAN_FRONTEND noninteractive\ ENV PYTHONUNBUFFERED=1 # Don't write .pyc bytecode -ENV PYTHONDONTWRITEBYTECODE=1 +# ENV PYTHONDONTWRITEBYTECODE=1 RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ --mount=type=cache,target=/var/lib/apt,sharing=locked \ @@ -65,7 +65,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \ echo "en_US.UTF-8 UTF-8" > /etc/locale.gen # Install runpodctl -RUN wget https://github.com/runpod/runpodctl/releases/download/v1.9.0/runpodctl-linux-amd -O runpodctl && \ +RUN wget https://github.com/runpod/runpodctl/releases/download/v1.14.3/runpodctl-linux-amd64 -O runpodctl && \ chmod a+x runpodctl && \ mv runpodctl /usr/local/bin diff --git a/docker/requirements-build.txt b/docker/requirements-build.txt index 46132f2..1b96e6c 100644 --- a/docker/requirements-build.txt +++ b/docker/requirements-build.txt @@ -1,4 +1,4 @@ -diffusers[torch]>=0.21.4 +diffusers[torch]>=0.27.2 ninja numpy omegaconf==2.2.3 diff --git a/docker/requirements-runtime.txt b/docker/requirements-runtime.txt index 2040781..b00687f 100644 --- a/docker/requirements-runtime.txt +++ b/docker/requirements-runtime.txt @@ -19,4 +19,5 @@ safetensors prodigyopt torchsde peft==0.9.0 -unidecode \ No newline at end of file +unidecode +tiktoken \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 078defe..61e6d06 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,4 +22,5 @@ numpy==1.23.5 wandb colorama safetensors -torchsde \ No newline at end of file +torchsde +tiktoken \ No newline at end of file diff --git a/scripts/check_images.py b/scripts/check_images.py index 81bac44..50dcdb5 100644 --- a/scripts/check_images.py +++ b/scripts/check_images.py @@ -26,6 +26,7 @@ def main(args): except Exception as e: print(f"FAILED: {path}") failed.append((path,e)) + if not failed: print("No errors found") else: @@ -33,7 +34,6 @@ def main(args): for path, e in failed: print(f"FAILED: {path} {e}") - if __name__ == '__main__': print("This script checks that all images in a directory are valid.") print("If any errors occur, they will be printed out at the end.") diff --git a/scripts/mt_grid.py b/scripts/mt_grid.py new file mode 100644 index 0000000..ebe3433 --- /dev/null +++ b/scripts/mt_grid.py @@ -0,0 +1,88 @@ +from diffusers import StableDiffusionPipeline +import torch +from torch.cuda.amp import autocast +import os + +import argparse +from PIL import Image, ImageDraw, ImageFont + +def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen, + steps: int = 30, batch_size: int = 1): + """ + generates a single sample at a given cfg scale and saves it to disk + """ + with autocast(): + images = pipe(prompt, + num_inference_steps=steps, + num_images_per_prompt=batch_size, + guidance_scale=cfg, + generator=gen, + height=height, + width=width, + ).images + + return images + +def generate_simple(prompt, model): + pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda") + images = __generate_sample(pipe, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1) + return images[0] + +if __name__ == "__main__": + argparser = argparse.ArgumentParser() + argparser.add_argument("--epochs", type=int, default=60) + args = argparser.parse_args() + epochs = args.epochs + + path = "/mnt/nvme/mt/val" + + model = None + if epochs == 100: + model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep100-gs159000" + model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep100-gs159000" + elif epochs == 80: + model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep80-gs127200" + model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep80-gs127200" + elif epochs == 60: + model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400" + model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep60-gs95400" + else: + raise ValueError("epochs must be 100, 80, or 60") + + pipe1 = StableDiffusionPipeline.from_pretrained(model1).to("cuda") + pipe2 = StableDiffusionPipeline.from_pretrained(model2).to("cuda") + + for root, dirs, files in os.walk(path): + for file in files: + if file.endswith(".txt") and not file.endswith("file_list.txt"): + txt_path = os.path.join(root, file) + with open(txt_path, "r", encoding="utf-8") as f: + prompt = f.readline() + generated_image1 = __generate_sample(pipe1, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0] + short_prompt = prompt.split(",")[0] + generated_image2 = __generate_sample(pipe2, short_prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0] + print(short_prompt) + gt_path = txt_path.replace(".txt", ".png") + print(f"Loading gt_path {gt_path}") + + ground_truth_image = Image.open(gt_path) + ground_truth_image = ground_truth_image.resize((512, 512)) + + combined_image = Image.new("RGB", (1536, 576), color=(96, 96, 96)) + combined_image.paste(ground_truth_image, (0, 0)) + combined_image.paste(generated_image1, (512, 0)) + combined_image.paste(generated_image2, (1024, 0)) + + draw = ImageDraw.Draw(combined_image) + font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 18) + draw.text((0, 510), f"epochs={epochs}", font=font) + draw.text((200, 510), "↑ ground truth ↑", font=font) + draw.text((650, 510), "↑ trained&generated full caption↑", font=font) + draw.text((1140, 510), "↑ trained&generated short caption ↑", font=font) + font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 24) + draw.text((100, 536), prompt, font=font) + draw.text((1240, 537), short_prompt, font=font) + + output_path = os.path.join("/mnt/nvme/mt", str(epochs), f"{file}_compare.png") + print(f"Saving to {output_path}") + combined_image.save(output_path) diff --git a/scripts/split_val.py b/scripts/split_val.py new file mode 100644 index 0000000..0694efa --- /dev/null +++ b/scripts/split_val.py @@ -0,0 +1,66 @@ +""" +Copyright [2022] Victor C Hall + +Licensed under the GNU Affero General Public License; +You may not use this code except in compliance with the License. +You may obtain a copy of the License at + + https://www.gnu.org/licenses/agpl-3.0.en.html + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +## Script to move random files from source_root to destination_root for use as validation split +## chooses (num_pairs_to_move) image/caption pairs from each subfolder in source_root and moves them to destination_root, preserving subfolder structure +## also creates a validation_captions.txt file in destination_root for inference testing later to see how model performs on unseen data +## only works on (png|jpg) + txt pairs. Will break if there is no .txt file or images are other extensions + +import os +import shutil +import random + +source_root = "/mnt/nvme/mydata_train" +destination_root = "/mnt/nvme/mydata_val" + +num_pairs_to_move = 3 # TODO: also do % based moving instead of fixed number + +def move_random_file_pairs(source_folder, destination_folder): + with open(os.path.join(destination_folder, "validation_captions.txt"), "w") as f: + for subdir, dirs, files in os.walk(source_folder): + for dir in dirs: + source_subfolder = os.path.join(subdir, dir) + destination_subfolder = os.path.join(destination_folder, dir) + + if not os.path.exists(destination_subfolder): + os.makedirs(destination_subfolder) + + file_list = [f for f in os.listdir(source_subfolder) if f.endswith((".png")) or f.endswith((".jpg"))] + + if len(file_list) >= num_pairs_to_move: + random.shuffle(file_list) + + for i in range(num_pairs_to_move): + file_name = file_list[i] + source_file = os.path.join(source_subfolder, file_name) + destination_file = os.path.join(destination_subfolder, file_name) + + caption_file_name = os.path.splitext(file_name)[0] + ".txt" + caption_source_file = os.path.join(source_subfolder, caption_file_name) + caption_destination_file = os.path.join(destination_subfolder, caption_file_name) + + with open(caption_source_file, "r") as caption_source: + caption = caption_source.readline() + f.write(caption) + f.write("\n") + print(f"Moving {caption_source_file} to {caption_destination_file}") + print(f"Moving {source_file} to {destination_file}") + print(f"Caption: {caption}\n") + shutil.move(source_file, destination_file) + shutil.move(caption_source_file, caption_destination_file) + + +move_random_file_pairs(source_root, destination_root) diff --git a/scripts/txt2img_grid_from_txt.py b/scripts/txt2img_grid_from_txt.py new file mode 100644 index 0000000..258173b --- /dev/null +++ b/scripts/txt2img_grid_from_txt.py @@ -0,0 +1,113 @@ +from diffusers import StableDiffusionPipeline +import torch +from torch.cuda.amp import autocast +import os + +import argparse +from PIL import Image + +def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen, + steps: int = 30, batch_size: int = 1): + """ + generates a single sample at a given cfg scale and saves it to disk + """ + with autocast(): + images = pipe(prompt, + num_inference_steps=steps, + num_images_per_prompt=batch_size, + guidance_scale=cfg, + generator=gen, + height=height, + width=width, + ).images + + return images + +def simple(): + pipe = StableDiffusionPipeline.from_pretrained("/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400").to("cuda") + images = __generate_sample(pipe, "bicycle", cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1) + images[0].save("test.png") + +if __name__ == "__main__": + #simple() + argparser = argparse.ArgumentParser() + argparser.add_argument("--prompt_file", type=str, required=False) + argparser.add_argument("--val_data_path", type=str, default="/mnt/nvme/mt/val", required=False) + argparser.add_argument("--models", nargs="+", help="names of models") + args = argparser.parse_args() + args.val_data_path = "/mnt/nvme/mt/val/00b42" + + W = 512 + H = 512 + BATCH_SIZE = 4 + + print(f"Generating grid image for {len(args.models)} models and {args.prompt_file}") + #print each args.models + print("Models:") + for m in args.models: + print(f" {m}") + + # with open(args.prompt_file, "r") as f: + # prompt_master_list = [] + # for x, line in enumerate(f): + # prompt_master_list.append(line.strip()) + + # open the txt files in args.val_data_path + prompt_master_list = {} + for f in os.listdir(args.val_data_path): + if f.endswith(".txt"): + txt_path = os.path.join(args.val_data_path, f) + with open(os.path.join(args.val_data_path, f), "r", encoding="utf-8") as f2: + img_path = os.path.splitext(f)[0] + ".png" + img_path = os.path.join(args.val_data_path, img_path) + prompt_master_list[img_path] = f2.readline().strip() + + print(f"Found {len(prompt_master_list)} images in {args.val_data_path}") + print(f"First 10 images: {list(prompt_master_list.values())[:10]}") + print() + + num_lines = len(prompt_master_list) + grid_h = (num_lines + 1) * W # num images plus blank for left column labels + grid_w = (1 + len(args.models)) * H # num models plus blank for top row labels + grid_img = Image.new("RGB", (grid_w, grid_h)) + + #num_iterations = len(prompt_master_list) // BATCH_SIZE + (len(prompt_master_list) % BATCH_SIZE > 0) + + chunked_dict_list = [] + chunk = {} + for key, value in prompt_master_list.items(): + chunk[key] = value + if len(chunk) == BATCH_SIZE: + chunked_dict_list.append(chunk) + chunk = {} + + # Append any remaining items if the total number of items is not a multiple of chunk_size + if chunk: + chunked_dict_list.append(chunk) + + # Iterate through the chunks + for i, chunk in enumerate(chunked_dict_list): + print(f"Chunk {i + 1}: {chunk}") + exit() + + for i_m, model in enumerate(args.models): + for j_p in range(chunked_dict_list): + start_index = j_p * BATCH_SIZE + end_index = (j_p + 1) * BATCH_SIZE + current_prompts = prompt_master_list[start_index:end_index] + + print(f"{model}: {current_prompts}") + print() + + if True: + pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda") + seed_generator = torch.Generator(pipe.device).manual_seed(555) + images = __generate_sample(pipe, current_prompts, cfg=7.5, height=512, width=512, gen=seed_generator, steps=40, batch_size=BATCH_SIZE) + # paste each image into the grid starting from H,W and incrementing by W + for k, k_img in enumerate(images): + k_img.save(f"tmp/{i_m}_{k}.png") + grid_img.paste(k_img, (W+k*W, H+H*i_m)) + # save the grid image + grid_img.save(f"tmp/grid.png") + + \ No newline at end of file diff --git a/train.json b/train.json index 4cc1b56..b646c4d 100644 --- a/train.json +++ b/train.json @@ -8,6 +8,7 @@ "data_root": "/mnt/q/training_samples/ff7r/man", "disable_amp": false, "disable_textenc_training": false, + "embedding_perturbation": 0.0, "flip_p": 0.0, "gpuid": 0, "gradient_checkpointing": true, @@ -28,8 +29,7 @@ "save_ckpts_from_n_epochs": 0, "save_every_n_epochs": 20, "save_optimizer": false, - "scale_lr": false, - "seed": 555, + "seed": -1, "timestep_start": 0, "timestep_end": 1000, "shuffle_tags": false, diff --git a/train.py b/train.py index 023bab5..d0dde5f 100644 --- a/train.py +++ b/train.py @@ -733,6 +733,7 @@ def main(args): try: print() + # currently broken on most systems? #unet = torch.compile(unet, mode="max-autotune") #text_encoder = torch.compile(text_encoder, mode="max-autotune") #vae = torch.compile(vae, mode="max-autotune") @@ -833,7 +834,6 @@ def main(args): logging.info( f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.") - ed_optimizer = EveryDreamOptimizer(args, optimizer_config, text_encoder, @@ -931,7 +931,7 @@ def main(args): assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct" # actual prediction function - shared between train and validate - def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None): + def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None, embedding_perturbation=0.0): with torch.no_grad(): with autocast(enabled=args.amp): pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device) @@ -968,6 +968,11 @@ def main(args): else: encoder_hidden_states = encoder_hidden_states.last_hidden_state + # https://arxiv.org/pdf/2405.20494 + perturbation_deviation = embedding_perturbation / math.sqrt(encoder_hidden_states.shape[2]) + perturbation_delta = torch.randn_like(encoder_hidden_states) * (perturbation_deviation) + encoder_hidden_states = encoder_hidden_states + perturbation_delta + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) if noise_scheduler.config.prediction_type == "epsilon": @@ -1195,7 +1200,8 @@ def main(args): batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True, - loss_scale=batch["loss_scale"]) + loss_scale=batch["loss_scale"], + embedding_perturbation=args.embedding_perturbation) del target, model_pred @@ -1362,6 +1368,7 @@ if __name__ == "__main__": argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)") argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)") argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED") + argparser.add_argument("--embedding_perturbation", type=float, default=0.0, help="random perturbation of text embeddings (def: 0.0)") argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!") argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids") argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)") diff --git a/windows_setup.cmd b/windows_setup.cmd index a300cf6..7a565d3 100644 --- a/windows_setup.cmd +++ b/windows_setup.cmd @@ -26,6 +26,7 @@ pip install prodigyopt pip install torchsde pip install peft>=0.9.0 pip install unidecode +pip install tiktoken python utils/get_yamls.py GOTO :eof