update caption to work with cog2 and glm-9v, add embedding_perturbation

This commit is contained in:
Victor Hall 2024-06-09 01:25:23 -04:00
parent d96b9cc56e
commit beec38726a
13 changed files with 475 additions and 121 deletions

View File

@ -47,7 +47,8 @@ except ImportError:
Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
IMAGE_SIZE: int = 490
IMAGE_SIZE_COG1: int = 490
IMAGE_SIZE_COG2: int = 1344
PATCH_SIZE: int = 14
torch.backends.cuda.matmul.allow_tf32 = True
@ -70,7 +71,7 @@ def save_params(args, gen_kwargs):
with open(save_path, "w") as f:
f.write(pretty_print)
def create_bnb_config(bnb_4bit_compute_dtype="bfloat16",bnb_4bit_quant_type= "fp4"):
def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "fp4"):
return BitsAndBytesConfig(
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_quant_type=bnb_4bit_quant_type,
@ -89,15 +90,26 @@ class BaseModelWrapper:
self.model_name = model_name
logging.info(f"Loading {model_name}")
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
def load_model(self, dtype: str="auto"):
bnb_config = self._maybe_create_bnb_config(dtype)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
quantization_config = bnb_config
).to(0)
self.tokenizer = AutoProcessor.from_pretrained(self.model_name)
return self.model, self.tokenizer
def _maybe_create_bnb_config(self, dtype, auto_bnb=True, auto_bnb_dtype="fp4"):
bnb_config = None
if dtype == "auto":
if auto_bnb:
bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=auto_bnb_dtype)
if dtype in ["nf4", "fp4"]:
bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=dtype)
return bnb_config
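# Rough summary of the dtype handling above (a sketch, not exhaustive):
#   "auto"          -> 4-bit BitsAndBytes fp4 config (unless a subclass passes auto_bnb=False)
#   "nf4" / "fp4"   -> 4-bit BitsAndBytes config with that quant type
#   "fp16" / "bf16" -> no BitsAndBytesConfig is created, the model loads unquantized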
def get_gen_kwargs(self, args):
gen_kwargs = {
@ -129,47 +141,7 @@ class BaseModelWrapper:
else:
logging.debug(f"** Sampling enabled")
return gen_kwargs
def caption(prompt, args):
return ""
class XtunerLlavaModelManager(BaseModelWrapper):
def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
super().__init__(model_name)
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
self.model = LlavaForConditionalGeneration.from_pretrained(
#self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
#quantization_config=create_bnb_config()
).to(0)
self.processor = LlavaProcessor.from_pretrained(self.model_name)
self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
print(f"self.tokenizer: {self.tokenizer}")
# tokens = self.tokenizer("foo")
# print(f"foo tokens test1: {tokens}")
return self.model, self.tokenizer
def get_inputs(self, image: Image.Image, prompt: str):
inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
return inputs
def _build_conversational_input_ids(self, prompt, starts_with):
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
def _truncate_to_whole_sentences(self, caption):
# model does not stop generating cleanly and cuts off mid sentence
caption = caption.split(".")
caption = ". ".join(caption[0:-1]) + "."
caption = caption.replace("\n","")
caption = caption.replace("  "," ")
return caption
def _clean_caption(self, caption, args):
"""
Removes some nonsense Llava adds.
@ -194,11 +166,52 @@ class XtunerLlavaModelManager(BaseModelWrapper):
caption = caption.replace(", who is the main subject of the photo.", ".")
caption = caption.replace(", who is the main subject.", ".")
caption = caption.replace("who is the main subject.", ".")
caption = caption.replace(", who is the central focus of the composition.", ".")
caption = caption.replace("who is the central focus of the composition.", ".")
caption = self._truncate_to_whole_sentences(caption)
logging.debug(f"**Llava post-cleaning caption: {caption}")
return caption
def caption(prompt, args):
return ""
class XtunerLlavaModelManager(BaseModelWrapper):
def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
super().__init__(model_name)
logging.info("Loading Xtuner Llava-Llama3 model...")
def load_model(self, dtype="auto"):
bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb=False)
self.model = LlavaForConditionalGeneration.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
low_cpu_mem_usage=True,
quantization_config=bnb_config
).to("cuda")
self.processor = LlavaProcessor.from_pretrained(self.model_name)
self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
return self.model, self.tokenizer
def get_inputs(self, image: Image.Image, prompt: str):
inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
return inputs
def _build_conversational_input_ids(self, prompt, starts_with):
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
def _truncate_to_whole_sentences(self, caption):
# model does not stop generating cleanly and cuts off mid sentence
caption = caption.split(".")
caption = ". ".join(caption[0:-1]) + "."
caption = caption.replace("\n","")
caption = caption.replace("  "," ")
return caption
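# Example of the truncation above (hypothetical caption): "A cat sits. The cat is orange. The ca"
# -> split on "." and drop the trailing fragment -> "A cat sits. The cat is orange."
# (the double space introduced by the join is collapsed by the replace("  ", " ") call)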
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
gen_kwargs = self.get_gen_kwargs(args)
@ -227,59 +240,110 @@ class XtunerLlavaModelManager(BaseModelWrapper):
caption = self._clean_caption(caption, args)
return caption
class MoaiManager:
# class MoaiManager:
# def __init__(self, model_name: str):
# self.model_name = model_name
# self.moai_model = None
# self.moai_processor = None
# self.seg_model = None
# self.seg_processor = None
# self.od_model = None
# self.od_processor = None
# self.sgg_model = None
# self.ocr_model = None
# logging.info("Loading Moai model...")
# def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
# moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
# = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
# self.moai_model = moai_model
# self.moai_processor = moai_processor
# self.seg_model = seg_model
# self.seg_processor = seg_processor
# self.od_model = od_model
# self.od_processor = od_processor
# self.sgg_model = sgg_model
# self.ocr_model = ocr_model
# return moai_model, moai_processor
# def get_inputs(self, image: Image.Image, prompt: str):
# moai_inputs = self.moai_model.demo_process(image=image,
# prompt=prompt,
# processor=self.moai_processor,
# seg_model=self.seg_model,
# seg_processor=self.seg_processor,
# od_model=self.od_model,
# od_processor=self.od_processor,
# sgg_model=self.sgg_model,
# ocr_model=self.ocr_model,
# device="cuda:0")
# return moai_inputs
class CogGLMManager(BaseModelWrapper):
def __init__(self, model_name: str):
self.model_name = model_name
self.moai_model = None
self.moai_processor = None
self.seg_model = None
self.seg_processor = None
self.od_model = None
self.od_processor = None
self.sgg_model = None
self.ocr_model = None
super().__init__(model_name)
if not model_name:
self.model_name = "THUDM/glm-4v-9b"
else:
self.model_name = model_name
logging.info("Loading CogGLM model...")
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
= prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
self.moai_model = moai_model
self.moai_processor = moai_processor
self.seg_model = seg_model
self.seg_processor = seg_processor
self.od_model = od_model
self.od_processor = od_processor
self.sgg_model = sgg_model
self.ocr_model = ocr_model
def load_model(self, dtype: str = "auto"):
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
bnb_config = self._maybe_create_bnb_config(dtype)
self.model = model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
trust_remote_code=True,
quantization_config=bnb_config
).eval()
if bnb_config is None:
# if BNB is used the model is automatically placed on the CUDA device, otherwise we need to move it manually
self.model = model.to("cuda")
return moai_model, moai_processor
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
gen_kwargs = self.get_gen_kwargs(args)
def get_inputs(self, image: Image.Image, prompt: str):
moai_inputs = self.moai_model.demo_process(image=image,
prompt=prompt,
processor=self.moai_processor,
seg_model=self.seg_model,
seg_processor=self.seg_processor,
od_model=self.od_model,
od_processor=self.od_processor,
sgg_model=self.sgg_model,
ocr_model=self.ocr_model,
device="cuda:0")
return moai_inputs
inputs = self.tokenizer.apply_chat_template([{"role": "user", "image": image, "content": prompt}],
add_generation_prompt=True, tokenize=True, return_tensors="pt",
return_dict=True)
inputs.to("cuda")
# def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
# with torch.inference_mode():
# generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
# answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
# return answer
outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
len_inputs = inputs['input_ids'].shape[1]
outputs_without_prompt = outputs[:, len_inputs:]
caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
return caption
class CogVLMManager(BaseModelWrapper):
def __init__(self, model_name: str):
super().__init__(model_name)
self.model_name = "THUDM/cogvlm-chat-hf"
if not model_name:
self.model_name = "THUDM/cogvlm-chat-hf"
self.cog_version = 1
elif model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
self.model_name = "THUDM/cogvlm2-llama3-chat-19b"
self.cog_version = 2
else:
self.model_name = model_name
self.cog_version = 1
patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
logging.info("Loading CogVLM model...")
def load_model(self):
self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
def load_model(self, dtype: str = "auto"):
if self.model_name.lower() in ["thudm/cogvlm-chat-hf", "thudm/cogagent-chat-hf"]:
self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
elif self.model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
self.tokenizer.pad_token_id = 128002 # for Llama 3
else:
raise ValueError("Unknown model name")
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
@ -297,7 +361,7 @@ class CogVLMManager(BaseModelWrapper):
starts_with: Optional[str] = None,
):
# based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py
image_size: int = IMAGE_SIZE
image_size: int = IMAGE_SIZE_COG2 if self.cog_version == 2 else IMAGE_SIZE_COG1
patch_size: int = PATCH_SIZE
assert images is None or len(images) <= 1, "multiple images are not supported yet."
history = history or []
@ -306,9 +370,8 @@ class CogVLMManager(BaseModelWrapper):
text += starts_with if starts_with is not None else ""
input_ids = [self.tokenizer.bos_token_id]
token_type_ids = [0]
token_type_ids = [0] # LANGUAGE_TOKEN_TYPE
if images is not None and len(images) == 1:
# vision
transform = transforms.Compose(
[
transforms.Resize(
@ -319,7 +382,11 @@ class CogVLMManager(BaseModelWrapper):
]
)
images = [transform(images[0])]
vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
if self.cog_version == 1:
vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
elif self.cog_version == 2:
vision_token_num = (image_size // patch_size // 2) * (image_size // patch_size // 2) + 2
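# Worked example from the constants above: cog v1 uses 490x490 inputs, so
# (490 // 14) ** 2 + 2 = 35 * 35 + 2 = 1227 vision tokens; cog v2 uses 1344x1344 with an
# extra 2x downsample, so (1344 // 14 // 2) ** 2 + 2 = 48 * 48 + 2 = 2306 vision tokens.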
input_ids += [self.tokenizer.pad_token_id] * vision_token_num
token_type_ids += [1] * vision_token_num
text_ids = self.tokenizer.encode(text, add_special_tokens=False)
@ -340,10 +407,6 @@ class CogVLMManager(BaseModelWrapper):
inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)
# inputs['input_ids'].shape: torch.Size([1259])
# inputs['attention_mask'].shape: torch.Size([1259])
# inputs['images'][0].shape: torch.Size([3, 490, 490])
inputs = {
"input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
"token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
@ -352,15 +415,8 @@ class CogVLMManager(BaseModelWrapper):
"output_hidden_states": True,
"return_dict": True
}
# inputs['input_ids'].shape: torch.Size([1, 1259])
# inputs['attention_mask'].shape: torch.Size([1, 1259])
# inputs['images'][0][0].shape: torch.Size([3, 490, 490])
# len(inputs['images'][0]): 1
# len(inputs['images'][0][0]): 3
outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
#print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
#print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
len_inputs = inputs['input_ids'].shape[1]
outputs_without_prompt = outputs[:, len_inputs:]
@ -369,12 +425,22 @@ class CogVLMManager(BaseModelWrapper):
return caption
def get_model_wrapper(model_name: str):
if "moai" in model_name:
return MoaiManager(model_name)
elif "llava" in model_name:
return XtunerLlavaModelManager(model_name)
else:
return CogVLMManager(model_name)
match model_name.casefold() if model_name else None:
# case x if "moai" in x:
# #return MoaiManager(model_name)
# return None
case x if "llava" in x:
return XtunerLlavaModelManager(model_name)
case "thudm/glm-4v-9b":
return CogGLMManager(model_name)
case "thudm/cogvlm2-llama3-chat-19b":
return CogVLMManager(model_name)
case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
return CogVLMManager(model_name)
case None:
return CogVLMManager(model_name)
case _:
raise ValueError(f"Model {model_name} not supported")
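# Sketch of the intended call flow for these wrappers (argument values are illustrative only):
#   wrapper = get_model_wrapper("THUDM/cogvlm2-llama3-chat-19b")
#   wrapper.load_model(dtype=args.dtype)
#   caption = wrapper.caption(prompt, image, args, force_words_ids=None, bad_words_ids=None)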
def get_inputs_dict(inputs):
inputs = {
@ -518,6 +584,7 @@ if __name__ == "__main__":
argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
argparser.add_argument("--dtype", choices=["auto","fp16","bf16","nf4","fp4"], default="auto", help="Data type for inference (def: auto, see docs)")
argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling")
@ -540,7 +607,7 @@ if __name__ == "__main__":
argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
argparser.add_argument("--model", type=str, default=None, help="Model to use for captioning.")
argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped")
args, unknown_args = argparser.parse_known_args()

View File

@ -8,13 +8,23 @@ It is capable of naming and identifying things with proper nouns and has a large
<a href="https://colab.research.google.com/github/nawnie/EveryDream2trainer/blob/main/CaptionCog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
Both the [Vicuna-based](https://huggingface.co/THUDM/cogvlm-chat-hf) and [Llama3-based](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) models are supported.
Choose between them with one of these two CLI args:
--model THUDM/cogvlm-chat-hf
--model THUDM/cogvlm2-llama3-chat-19B
The script defaults to the Vicuna-based model if no `--model` arg is specified.
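For example, add `--model THUDM/cogvlm2-llama3-chat-19B` to caption with the Llama3-based model. The `--dtype` arg (choices: `auto`, `fp16`, `bf16`, `nf4`, `fp4`) controls load precision and quantization.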
## Llava update
This script now (confusingly, given its name) also supports [Xtuner's Llava Llama3 8b v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main).
To use it, add `--model xtuner/llava-llama-3-8b-v1_1-transformers` to your command line.
This is a work in progress. So far it seems `bad_words` does not work.
When using Llava, the script performs some clean-up to remove less-than-useful language from the caption, because the `bad_words` feature of the Hugging Face Transformers API is not supported by Llava.
## Basics

View File

@ -5,7 +5,7 @@ LABEL org.opencontainers.image.licenses="AGPL-3.0-only"
ARG DEBIAN_FRONTEND=noninteractive
# Don't write .pyc bytecode
ENV PYTHONDONTWRITEBYTECODE=1
# ENV PYTHONDONTWRITEBYTECODE=1
# Create workspace working directory
RUN mkdir /build
@ -49,7 +49,7 @@ ENV DEBIAN_FRONTEND noninteractive\
ENV PYTHONUNBUFFERED=1
# Don't write .pyc bytecode
ENV PYTHONDONTWRITEBYTECODE=1
# ENV PYTHONDONTWRITEBYTECODE=1
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
--mount=type=cache,target=/var/lib/apt,sharing=locked \
@ -65,7 +65,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
# Install runpodctl
RUN wget https://github.com/runpod/runpodctl/releases/download/v1.9.0/runpodctl-linux-amd -O runpodctl && \
RUN wget https://github.com/runpod/runpodctl/releases/download/v1.14.3/runpodctl-linux-amd64 -O runpodctl && \
chmod a+x runpodctl && \
mv runpodctl /usr/local/bin

View File

@ -1,4 +1,4 @@
diffusers[torch]>=0.21.4
diffusers[torch]>=0.27.2
ninja
numpy
omegaconf==2.2.3

View File

@ -19,4 +19,5 @@ safetensors
prodigyopt
torchsde
peft==0.9.0
unidecode
unidecode
tiktoken

View File

@ -22,4 +22,5 @@ numpy==1.23.5
wandb
colorama
safetensors
torchsde
torchsde
tiktoken

View File

@ -26,6 +26,7 @@ def main(args):
except Exception as e:
print(f"FAILED: {path}")
failed.append((path,e))
if not failed:
print("No errors found")
else:
@ -33,7 +34,6 @@ def main(args):
for path, e in failed:
print(f"FAILED: {path} {e}")
if __name__ == '__main__':
print("This script checks that all images in a directory are valid.")
print("If any errors occur, they will be printed out at the end.")

88 scripts/mt_grid.py Normal file
View File

@ -0,0 +1,88 @@
from diffusers import StableDiffusionPipeline
import torch
from torch.cuda.amp import autocast
import os
import argparse
from PIL import Image, ImageDraw, ImageFont
def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
steps: int = 30, batch_size: int = 1):
"""
generates a single sample at a given cfg scale and saves it to disk
"""
with autocast():
images = pipe(prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=cfg,
generator=gen,
height=height,
width=width,
).images
return images
def generate_simple(prompt, model):
pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
images = __generate_sample(pipe, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
return images[0]
if __name__ == "__main__":
argparser = argparse.ArgumentParser()
argparser.add_argument("--epochs", type=int, default=60)
args = argparser.parse_args()
epochs = args.epochs
path = "/mnt/nvme/mt/val"
model = None
if epochs == 100:
model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep100-gs159000"
model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep100-gs159000"
elif epochs == 80:
model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep80-gs127200"
model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep80-gs127200"
elif epochs == 60:
model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400"
model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep60-gs95400"
else:
raise ValueError("epochs must be 100, 80, or 60")
pipe1 = StableDiffusionPipeline.from_pretrained(model1).to("cuda")
pipe2 = StableDiffusionPipeline.from_pretrained(model2).to("cuda")
for root, dirs, files in os.walk(path):
for file in files:
if file.endswith(".txt") and not file.endswith("file_list.txt"):
txt_path = os.path.join(root, file)
with open(txt_path, "r", encoding="utf-8") as f:
prompt = f.readline()
generated_image1 = __generate_sample(pipe1, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
short_prompt = prompt.split(",")[0]
generated_image2 = __generate_sample(pipe2, short_prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
print(short_prompt)
gt_path = txt_path.replace(".txt", ".png")
print(f"Loading gt_path {gt_path}")
ground_truth_image = Image.open(gt_path)
ground_truth_image = ground_truth_image.resize((512, 512))
combined_image = Image.new("RGB", (1536, 576), color=(96, 96, 96))
combined_image.paste(ground_truth_image, (0, 0))
combined_image.paste(generated_image1, (512, 0))
combined_image.paste(generated_image2, (1024, 0))
draw = ImageDraw.Draw(combined_image)
font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 18)
draw.text((0, 510), f"epochs={epochs}", font=font)
draw.text((200, 510), "↑ ground truth ↑", font=font)
draw.text((650, 510), "↑ trained&generated full caption↑", font=font)
draw.text((1140, 510), "↑ trained&generated short caption ↑", font=font)
font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 24)
draw.text((100, 536), prompt, font=font)
draw.text((1240, 537), short_prompt, font=font)
output_path = os.path.join("/mnt/nvme/mt", str(epochs), f"{file}_compare.png")
print(f"Saving to {output_path}")
combined_image.save(output_path)

66 scripts/split_val.py Normal file
View File

@ -0,0 +1,66 @@
"""
Copyright [2022] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
## Script to move random files from source_root to destination_root for use as validation split
## chooses (num_pairs_to_move) image/caption pairs from each subfolder in source_root and moves them to destination_root, preserving subfolder structure
## also creates a validation_captions.txt file in destination_root for inference testing later to see how model performs on unseen data
## only works on (png|jpg) + txt pairs. Will break if there is no .txt file or images are other extensions
import os
import shutil
import random
source_root = "/mnt/nvme/mydata_train"
destination_root = "/mnt/nvme/mydata_val"
num_pairs_to_move = 3 # TODO: also do % based moving instead of fixed number
def move_random_file_pairs(source_folder, destination_folder):
with open(os.path.join(destination_folder, "validation_captions.txt"), "w") as f:
for subdir, dirs, files in os.walk(source_folder):
for dir in dirs:
source_subfolder = os.path.join(subdir, dir)
destination_subfolder = os.path.join(destination_folder, dir)
if not os.path.exists(destination_subfolder):
os.makedirs(destination_subfolder)
file_list = [f for f in os.listdir(source_subfolder) if f.endswith((".png")) or f.endswith((".jpg"))]
if len(file_list) >= num_pairs_to_move:
random.shuffle(file_list)
for i in range(num_pairs_to_move):
file_name = file_list[i]
source_file = os.path.join(source_subfolder, file_name)
destination_file = os.path.join(destination_subfolder, file_name)
caption_file_name = os.path.splitext(file_name)[0] + ".txt"
caption_source_file = os.path.join(source_subfolder, caption_file_name)
caption_destination_file = os.path.join(destination_subfolder, caption_file_name)
with open(caption_source_file, "r") as caption_source:
caption = caption_source.readline()
f.write(caption)
f.write("\n")
print(f"Moving {caption_source_file} to {caption_destination_file}")
print(f"Moving {source_file} to {destination_file}")
print(f"Caption: {caption}\n")
shutil.move(source_file, destination_file)
shutil.move(caption_source_file, caption_destination_file)
move_random_file_pairs(source_root, destination_root)

View File

@ -0,0 +1,113 @@
from diffusers import StableDiffusionPipeline
import torch
from torch.cuda.amp import autocast
import os
import argparse
from PIL import Image
def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
steps: int = 30, batch_size: int = 1):
"""
generates a single sample at a given cfg scale and saves it to disk
"""
with autocast():
images = pipe(prompt,
num_inference_steps=steps,
num_images_per_prompt=batch_size,
guidance_scale=cfg,
generator=gen,
height=height,
width=width,
).images
return images
def simple():
pipe = StableDiffusionPipeline.from_pretrained("/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400").to("cuda")
images = __generate_sample(pipe, "bicycle", cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
images[0].save("test.png")
if __name__ == "__main__":
#simple()
argparser = argparse.ArgumentParser()
argparser.add_argument("--prompt_file", type=str, required=False)
argparser.add_argument("--val_data_path", type=str, default="/mnt/nvme/mt/val", required=False)
argparser.add_argument("--models", nargs="+", help="names of models")
args = argparser.parse_args()
args.val_data_path = "/mnt/nvme/mt/val/00b42"
W = 512
H = 512
BATCH_SIZE = 4
print(f"Generating grid image for {len(args.models)} models and {args.prompt_file}")
#print each args.models
print("Models:")
for m in args.models:
print(f" {m}")
# with open(args.prompt_file, "r") as f:
# prompt_master_list = []
# for x, line in enumerate(f):
# prompt_master_list.append(line.strip())
# open the txt files in args.val_data_path
prompt_master_list = {}
for f in os.listdir(args.val_data_path):
if f.endswith(".txt"):
txt_path = os.path.join(args.val_data_path, f)
with open(os.path.join(args.val_data_path, f), "r", encoding="utf-8") as f2:
img_path = os.path.splitext(f)[0] + ".png"
img_path = os.path.join(args.val_data_path, img_path)
prompt_master_list[img_path] = f2.readline().strip()
print(f"Found {len(prompt_master_list)} images in {args.val_data_path}")
print(f"First 10 images: {list(prompt_master_list.values())[:10]}")
print()
num_lines = len(prompt_master_list)
grid_h = (num_lines + 1) * W # num images plus blank for left column labels
grid_w = (1 + len(args.models)) * H # num models plus blank for top row labels
grid_img = Image.new("RGB", (grid_w, grid_h))
#num_iterations = len(prompt_master_list) // BATCH_SIZE + (len(prompt_master_list) % BATCH_SIZE > 0)
chunked_dict_list = []
chunk = {}
for key, value in prompt_master_list.items():
chunk[key] = value
if len(chunk) == BATCH_SIZE:
chunked_dict_list.append(chunk)
chunk = {}
# Append any remaining items if the total number of items is not a multiple of chunk_size
if chunk:
chunked_dict_list.append(chunk)
# Iterate through the chunks
for i, chunk in enumerate(chunked_dict_list):
print(f"Chunk {i + 1}: {chunk}")
exit()
for i_m, model in enumerate(args.models):
for j_p, chunk in enumerate(chunked_dict_list):
current_prompts = list(chunk.values())
print(f"{model}: {current_prompts}")
print()
if True:
pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
seed_generator = torch.Generator(pipe.device).manual_seed(555)
images = __generate_sample(pipe, current_prompts, cfg=7.5, height=512, width=512, gen=seed_generator, steps=40, batch_size=BATCH_SIZE)
# paste each image into the grid starting from H,W and incrementing by W
for k, k_img in enumerate(images):
k_img.save(f"tmp/{i_m}_{k}.png")
grid_img.paste(k_img, (W+k*W, H+H*i_m))
# save the grid image
grid_img.save(f"tmp/grid.png")

View File

@ -8,6 +8,7 @@
"data_root": "/mnt/q/training_samples/ff7r/man",
"disable_amp": false,
"disable_textenc_training": false,
"embedding_perturbation": 0.0,
"flip_p": 0.0,
"gpuid": 0,
"gradient_checkpointing": true,
@ -28,8 +29,7 @@
"save_ckpts_from_n_epochs": 0,
"save_every_n_epochs": 20,
"save_optimizer": false,
"scale_lr": false,
"seed": 555,
"seed": -1,
"timestep_start": 0,
"timestep_end": 1000,
"shuffle_tags": false,

View File

@ -733,6 +733,7 @@ def main(args):
try:
print()
# currently broken on most systems?
#unet = torch.compile(unet, mode="max-autotune")
#text_encoder = torch.compile(text_encoder, mode="max-autotune")
#vae = torch.compile(vae, mode="max-autotune")
@ -833,7 +834,6 @@ def main(args):
logging.info(
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")
ed_optimizer = EveryDreamOptimizer(args,
optimizer_config,
text_encoder,
@ -931,7 +931,7 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None):
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None, embedding_perturbation=0.0):
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@ -968,6 +968,11 @@ def main(args):
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
# https://arxiv.org/pdf/2405.20494
perturbation_deviation = embedding_perturbation / math.sqrt(encoder_hidden_states.shape[2])
perturbation_delta = torch.randn_like(encoder_hidden_states) * (perturbation_deviation)
encoder_hidden_states = encoder_hidden_states + perturbation_delta
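# Worked example (hypothetical values): with embedding_perturbation=0.05 and the SD1.x text
# encoder hidden size of 768, the per-element noise std is 0.05 / sqrt(768) ~= 0.0018, so the
# effective perturbation strength stays comparable across embedding widths.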
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if noise_scheduler.config.prediction_type == "epsilon":
@ -1195,7 +1200,8 @@ def main(args):
batch["tokens"],
args.zero_frequency_noise_ratio,
return_loss=True,
loss_scale=batch["loss_scale"])
loss_scale=batch["loss_scale"],
embedding_perturbation=args.embedding_perturbation)
del target, model_pred
@ -1362,6 +1368,7 @@ if __name__ == "__main__":
argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--embedding_perturbation", type=float, default=0.0, help="random perturbation of text embeddings (def: 0.0)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")

View File

@ -26,6 +26,7 @@ pip install prodigyopt
pip install torchsde
pip install peft>=0.9.0
pip install unidecode
pip install tiktoken
python utils/get_yamls.py
GOTO :eof