update caption to work with cog2 and glm-9v, add embedding_perturbation

Victor Hall 2024-06-09 01:25:23 -04:00
parent d96b9cc56e
commit beec38726a
13 changed files with 475 additions and 121 deletions

View File

@@ -47,7 +47,8 @@ except ImportError:
 Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
 
-IMAGE_SIZE: int = 490
+IMAGE_SIZE_COG1: int = 490
+IMAGE_SIZE_COG2: int = 1344
 PATCH_SIZE: int = 14
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -70,7 +71,7 @@ def save_params(args, gen_kwargs):
     with open(save_path, "w") as f:
         f.write(pretty_print)
 
-def create_bnb_config(bnb_4bit_compute_dtype="bfloat16",bnb_4bit_quant_type= "fp4"):
+def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "fp4"):
     return BitsAndBytesConfig(
         bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
         bnb_4bit_quant_type=bnb_4bit_quant_type,
@@ -89,15 +90,26 @@ class BaseModelWrapper:
         self.model_name = model_name
         logging.info(f"Loading {model_name}")
 
-    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
+    def load_model(self, dtype: str="auto"):
+        bnb_config = self._maybe_create_bnb_config(dtype)
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.float16,
             low_cpu_mem_usage=True,
+            quantization_config = bnb_config
         ).to(0)
         self.tokenizer = AutoProcessor.from_pretrained(self.model_name)
         return self.model, self.tokenizer
 
+    def _maybe_create_bnb_config(self, dtype, auto_bnb=True, auto_bnb_dtype="fp4"):
+        bnb_config = None
+        if dtype == "auto":
+            if auto_bnb:
+                bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=auto_bnb_dtype)
+        if dtype in ["nf4", "fp4"]:
+            bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=dtype)
+        return bnb_config
+
     def get_gen_kwargs(self, args):
         gen_kwargs = {
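For reference, the new `--dtype` string maps onto bitsandbytes quantization roughly as sketched below. This is an illustrative, standalone restatement of `_maybe_create_bnb_config`, not code from the script: the name `pick_quantization` and the `load_in_4bit=True` flag are assumptions, since the full body of `create_bnb_config` is not shown in this hunk.

import torch
from transformers import BitsAndBytesConfig

def pick_quantization(dtype: str = "auto"):
    """Illustrative mapping of the --dtype choices to a quantization config.
    "auto", "nf4" and "fp4" return a BitsAndBytesConfig; "fp16"/"bf16"
    return None so the model loads unquantized."""
    if dtype == "auto":
        # default: 4-bit fp4 weights with bf16 compute
        return BitsAndBytesConfig(
            load_in_4bit=True,  # assumed; set inside the script's create_bnb_config
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type="fp4",
        )
    if dtype in ("nf4", "fp4"):
        return BitsAndBytesConfig(
            load_in_4bit=True,  # assumed
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type=dtype,
        )
    return None  # "fp16" / "bf16": no bitsandbytes quantization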
@@ -129,47 +141,7 @@ class BaseModelWrapper:
         else:
             logging.debug(f"** Sampling enabled")
 
         return gen_kwargs
 
-    def caption(prompt, args):
-        return ""
-
-class XtunerLlavaModelManager(BaseModelWrapper):
-    def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
-        self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
-        super().__init__(model_name)
-
-    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
-        self.model = LlavaForConditionalGeneration.from_pretrained(
-        #self.model = AutoModelForCausalLM.from_pretrained(
-            self.model_name,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True,
-            #quantization_config=create_bnb_config()
-        ).to(0)
-        self.processor = LlavaProcessor.from_pretrained(self.model_name)
-        self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
-        print(f"self.tokenizer: {self.tokenizer}")
-        # tokens = self.tokenizer("foo")
-        # print(f"foo tokens test1: {tokens}")
-        return self.model, self.tokenizer
-
-    def get_inputs(self, image: Image.Image, prompt: str):
-        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
-        return inputs
-
-    def _build_conversational_input_ids(self, prompt, starts_with):
-        return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
-                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
-
-    def _truncate_to_whole_sentences(self, caption):
-        # model does not stop generating cleanly and cuts off mid sentence
-        caption = caption.split(".")
-        caption = ". ".join(caption[0:-1]) + "."
-        caption = caption.replace("\n","")
-        caption = caption.replace("  "," ")
-        return caption
-
     def _clean_caption(self, caption, args):
         """
         Removes some nonsense Llava adds.
@@ -194,11 +166,52 @@ class XtunerLlavaModelManager(BaseModelWrapper):
         caption = caption.replace(", who is the main subject of the photo.", ".")
         caption = caption.replace(", who is the main subject.", ".")
         caption = caption.replace("who is the main subject.", ".")
+        caption = caption.replace(", who is the central focus of the composition.", ".")
+        caption = caption.replace("who is the central focus of the composition.", ".")
 
         caption = self._truncate_to_whole_sentences(caption)
         logging.debug(f"**Llava post-cleaning caption: {caption}")
         return caption
 
+    def caption(prompt, args):
+        return ""
+
+class XtunerLlavaModelManager(BaseModelWrapper):
+    def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
+        self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
+        super().__init__(model_name)
+        logging.info("Loading Xtuner Llava-Llama3 model...")
+
+    def load_model(self, dtype="auto"):
+        bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb=False)
+        self.model = LlavaForConditionalGeneration.from_pretrained(
+            self.model_name,
+            torch_dtype=torch.float16,
+            low_cpu_mem_usage=True,
+            quantization_config=bnb_config
+        ).to("cuda")
+        self.processor = LlavaProcessor.from_pretrained(self.model_name)
+        self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
+        return self.model, self.tokenizer
+
+    def get_inputs(self, image: Image.Image, prompt: str):
+        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
+        return inputs
+
+    def _build_conversational_input_ids(self, prompt, starts_with):
+        return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
+                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
+
+    def _truncate_to_whole_sentences(self, caption):
+        # model does not stop generating cleanly and cuts off mid sentence
+        caption = caption.split(".")
+        caption = ". ".join(caption[0:-1]) + "."
+        caption = caption.replace("\n","")
+        caption = caption.replace("  "," ")
+        return caption
+
     def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
         gen_kwargs = self.get_gen_kwargs(args)
@@ -227,59 +240,110 @@ class XtunerLlavaModelManager(BaseModelWrapper):
         caption = self._clean_caption(caption, args)
         return caption
 
-class MoaiManager:
-    def __init__(self, model_name: str):
-        self.model_name = model_name
-        self.moai_model = None
-        self.moai_processor = None
-        self.seg_model = None
-        self.seg_processor = None
-        self.od_model = None
-        self.od_processor = None
-        self.sgg_model = None
-        self.ocr_model = None
-
-    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
-        moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
-            = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
-        self.moai_model = moai_model
-        self.moai_processor = moai_processor
-        self.seg_model = seg_model
-        self.seg_processor = seg_processor
-        self.od_model = od_model
-        self.od_processor = od_processor
-        self.sgg_model = sgg_model
-        self.ocr_model = ocr_model
-
-        return moai_model, moai_processor
-
-    def get_inputs(self, image: Image.Image, prompt: str):
-        moai_inputs = self.moai_model.demo_process(image=image,
-                                                   prompt=prompt,
-                                                   processor=self.moai_processor,
-                                                   seg_model=self.seg_model,
-                                                   seg_processor=self.seg_processor,
-                                                   od_model=self.od_model,
-                                                   od_processor=self.od_processor,
-                                                   sgg_model=self.sgg_model,
-                                                   ocr_model=self.ocr_model,
-                                                   device="cuda:0")
-        return moai_inputs
-
-    # def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
-    #     with torch.inference_mode():
-    #         generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
-    #     answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
-    #     return answer
-
+# class MoaiManager:
+#     def __init__(self, model_name: str):
+#         self.model_name = model_name
+#         self.moai_model = None
+#         self.moai_processor = None
+#         self.seg_model = None
+#         self.seg_processor = None
+#         self.od_model = None
+#         self.od_processor = None
+#         self.sgg_model = None
+#         self.ocr_model = None
+#         logging.info("Loading Moai model...")
+
+#     def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
+#         moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
+#             = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
+#         self.moai_model = moai_model
+#         self.moai_processor = moai_processor
+#         self.seg_model = seg_model
+#         self.seg_processor = seg_processor
+#         self.od_model = od_model
+#         self.od_processor = od_processor
+#         self.sgg_model = sgg_model
+#         self.ocr_model = ocr_model
+#         return moai_model, moai_processor
+
+#     def get_inputs(self, image: Image.Image, prompt: str):
+#         moai_inputs = self.moai_model.demo_process(image=image,
+#                                                    prompt=prompt,
+#                                                    processor=self.moai_processor,
+#                                                    seg_model=self.seg_model,
+#                                                    seg_processor=self.seg_processor,
+#                                                    od_model=self.od_model,
+#                                                    od_processor=self.od_processor,
+#                                                    sgg_model=self.sgg_model,
+#                                                    ocr_model=self.ocr_model,
+#                                                    device="cuda:0")
+#         return moai_inputs
+
+class CogGLMManager(BaseModelWrapper):
+    def __init__(self, model_name: str):
+        super().__init__(model_name)
+        if not model_name:
+            self.model_name = "THUDM/cogglm-6b"
+        else:
+            self.model_name = model_name
+        logging.info("Loading CogGLM model...")
+
+    def load_model(self, dtype: str = "auto"):
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
+        bnb_config = None
+        if dtype in ["auto","nf4"]:
+            bnb_config = create_bnb_config()
+        self.model = model = AutoModelForCausalLM.from_pretrained(
+            "THUDM/glm-4v-9b",
+            torch_dtype=torch.bfloat16,
+            low_cpu_mem_usage=True,
+            trust_remote_code=True,
+            quantization_config=bnb_config
+        ).eval()
+        if bnb_config is None:
+            # if BNB is used it is automatically sent to cuda device, otherwise need to move it manually
+            self.model = model.to("cuda")
+
+    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
+        gen_kwargs = self.get_gen_kwargs(args)
+        inputs = self.tokenizer.apply_chat_template([{"role": "user", "image": image, "content": prompt}],
+                                                    add_generation_prompt=True, tokenize=True, return_tensors="pt",
+                                                    return_dict=True)
+        inputs.to("cuda")
+
+        outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
+
+        len_inputs = inputs['input_ids'].shape[1]
+        outputs_without_prompt = outputs[:, len_inputs:]
+
+        caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
+        return caption
+
 class CogVLMManager(BaseModelWrapper):
     def __init__(self, model_name: str):
         super().__init__(model_name)
-        self.model_name = "THUDM/cogvlm-chat-hf"
+        if not model_name:
+            self.model_name = "THUDM/cogvlm-chat-hf"
+            self.cog_version = 1
+        elif model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
+            self.model_name = "THUDM/cogvlm2-llama3-chat-19b"
+            self.cog_version = 2
+        else:
+            self.model_name = model_name
+            self.cog_version = 1
         patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
+        logging.info("Loading CogVLM model...")
 
-    def load_model(self):
-        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+    def load_model(self, dtype: str = "auto"):
+        if self.model_name.lower() == "THUDM/cogvlm-chat-hf".lower():
+            self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
+        elif self.model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
+            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
+            self.tokenizer.pad_token_id = 128002 # for Llama 3
+        else:
+            raise ValueError("Unknown model name")
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
             torch_dtype=torch.bfloat16,
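For reference, the glm-4v-9b path wrapped by the new CogGLMManager follows the model card's chat-template usage. A self-contained sketch is below; the image path, prompt, and generation settings are illustrative, not the script's defaults, and it loads in bf16 rather than the script's 4-bit auto mode, so it assumes a GPU with ample VRAM.

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4v-9b", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/glm-4v-9b",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).eval().to("cuda")

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "image": image, "content": "Write a description."}],
    add_generation_prompt=True, tokenize=True, return_tensors="pt", return_dict=True,
).to("cuda")

with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=128, num_beams=2)

# generate() returns prompt ids plus new ids; decode only the newly generated part
caption = tokenizer.decode(outputs[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(caption)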
@@ -297,7 +361,7 @@ class CogVLMManager(BaseModelWrapper):
         starts_with: Optional[str] = None,
     ):
         # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py
-        image_size: int = IMAGE_SIZE
+        image_size: int = IMAGE_SIZE_COG2 if self.cog_version == 2 else IMAGE_SIZE_COG1
         patch_size: int = PATCH_SIZE
         assert images is None or len(images) <= 1, f"not support multi images by now."
         history = history or []
@@ -306,9 +370,8 @@ class CogVLMManager(BaseModelWrapper):
         text += starts_with if starts_with is not None else ""
 
         input_ids = [self.tokenizer.bos_token_id]
-        token_type_ids = [0]
+        token_type_ids = [0] # LANGUAGE_TOKEN_TYPE
         if images is not None and len(images) == 1:
-            # vision
             transform = transforms.Compose(
                 [
                     transforms.Resize(
@@ -319,7 +382,11 @@ class CogVLMManager(BaseModelWrapper):
                 ]
             )
             images = [transform(images[0])]
-            vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
+            if self.cog_version == 1:
+                vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
+            elif self.cog_version == 2:
+                vision_token_num = (image_size // patch_size // 2) * (image_size // patch_size // 2) + 2
             input_ids += [self.tokenizer.pad_token_id] * vision_token_num
             token_type_ids += [1] * vision_token_num
         text_ids = self.tokenizer.encode(text, add_special_tokens=False)
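Plugging the new constants into these formulas makes the two models' vision sequence lengths concrete (a quick arithmetic check, not code from the script):

# CogVLM v1: 490px image, 14px patches
cog1_tokens = (490 // 14) * (490 // 14) + 2              # 35 * 35 + 2 = 1227
# CogVLM2: 1344px image, 14px patches, plus a 2x downsample of the patch grid
cog2_tokens = (1344 // 14 // 2) * (1344 // 14 // 2) + 2  # 48 * 48 + 2 = 2306
print(cog1_tokens, cog2_tokens)  # 1227 2306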
@@ -340,10 +407,6 @@ class CogVLMManager(BaseModelWrapper):
 
         inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)
 
-        # inputs['input_ids'].shape: torch.Size([1259])
-        # inputs['attention_mask'].shape: torch.Size([1259])
-        # inputs['images'][0].shape: torch.Size([3, 490, 490])
-
         inputs = {
             "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
             "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
@@ -352,15 +415,8 @@ class CogVLMManager(BaseModelWrapper):
             "output_hidden_states": True,
             "return_dict": True
         }
-        # inputs['input_ids'].shape: torch.Size([1, 1259])
-        # inputs['attention_mask'].shape: torch.Size([1, 1259])
-        # inputs['images'][0][0].shape: torch.Size([3, 490, 490])
-        # len(inputs['images'][0]): 1
-        # len(inputs['images'][0][0]): 3
 
         outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
-        #print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
-        #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
 
         len_inputs = inputs['input_ids'].shape[1]
         outputs_without_prompt = outputs[:, len_inputs:]
@@ -369,12 +425,22 @@ class CogVLMManager(BaseModelWrapper):
         return caption
 
 def get_model_wrapper(model_name: str):
-    if "moai" in model_name:
-        return MoaiManager(model_name)
-    elif "llava" in model_name:
-        return XtunerLlavaModelManager(model_name)
-    else:
-        return CogVLMManager(model_name)
+    match model_name.casefold() if model_name else None:  # args.model may be None, guard before casefold
+        # case x if "moai" in x:
+        #     #return MoaiManager(model_name)
+        #     return None
+        case x if "llava" in x:
+            return XtunerLlavaModelManager(model_name)
+        case "thudm/glm-4v-9b":
+            return CogGLMManager(model_name)
+        case "thudm/cogvlm2-llama3-chat-19b":
+            return CogVLMManager(model_name)
+        case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
+            return CogVLMManager(model_name)
+        case None:
+            return CogVLMManager(model_name)
+        case _:
+            raise ValueError(f"Model {model_name} not supported")
 
 def get_inputs_dict(inputs):
     inputs = {
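Taken together, the dispatch and loader are presumably wired up along these lines elsewhere in the script; the snippet below is illustrative only, since the actual call sites are outside this diff.

# illustrative wiring of the dispatch + load path introduced above
wrapper = get_model_wrapper("THUDM/glm-4v-9b")       # -> CogGLMManager
wrapper = get_model_wrapper("THUDM/cogvlm-chat-hf")  # -> CogVLMManager, cog_version 1
wrapper = get_model_wrapper(None)                    # -> CogVLMManager, the default when --model is omitted
model, tokenizer = wrapper.load_model(dtype="auto")  # "auto" picks 4-bit quantization where supported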
@@ -518,6 +584,7 @@ if __name__ == "__main__":
     argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
     argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
     argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
+    argparser.add_argument("--dtype", choices=["auto","fp16","bf16","nf4","fp4"], default="auto", help="Data type for inference (def: auto, see docs)")
     argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
     argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
     argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling")
@@ -540,7 +607,7 @@ if __name__ == "__main__":
     argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
     argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
     argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
-    argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
+    argparser.add_argument("--model", type=str, default=None, help="Model to use for captioning.")
     argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped")
     args, unknown_args = argparser.parse_known_args()

View File

@@ -8,13 +8,23 @@ It is capable of naming and identifying things with proper nouns and has a large
 <a href="https://colab.research.google.com/github/nawnie/EveryDream2trainer/blob/main/CaptionCog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
 
+Both the [Vicuna-based](https://huggingface.co/THUDM/cogvlm-chat-hf) and [Llama3-based](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) models are supported. Choose between them with one of these two CLI args:
+
+    --model THUDM/cogvlm-chat-hf
+
+    --model THUDM/cogvlm2-llama3-chat-19B
+
+The script uses the Vicuna-based model by default if no `--model` arg is specified.
+
 ## Llava update
 
 This script now (confusingly) also supports [Xtuner's Llava Llama3 8b v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main).
 
 To use, add `--model "https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main"` to your command line.
 
-This is a work in progress. So far it seems bad_words do not work.
+When using Llava, the script performs some clean-up to remove less-than-useful language from the caption, because the bad_words feature of the Hugging Face Transformers API is not supported by Llava.
 
 ## Basics

View File

@@ -5,7 +5,7 @@ LABEL org.opencontainers.image.licenses="AGPL-3.0-only"
 ARG DEBIAN_FRONTEND=noninteractive
 
 # Don't write .pyc bytecode
-ENV PYTHONDONTWRITEBYTECODE=1
+# ENV PYTHONDONTWRITEBYTECODE=1
 
 # Create workspace working directory
 RUN mkdir /build
@@ -49,7 +49,7 @@ ENV DEBIAN_FRONTEND noninteractive\
 ENV PYTHONUNBUFFERED=1
 
 # Don't write .pyc bytecode
-ENV PYTHONDONTWRITEBYTECODE=1
+# ENV PYTHONDONTWRITEBYTECODE=1
 
 RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -65,7 +65,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
     echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
 
 # Install runpodctl
-RUN wget https://github.com/runpod/runpodctl/releases/download/v1.9.0/runpodctl-linux-amd -O runpodctl && \
+RUN wget https://github.com/runpod/runpodctl/releases/download/v1.14.3/runpodctl-linux-amd64 -O runpodctl && \
     chmod a+x runpodctl && \
     mv runpodctl /usr/local/bin

View File

@@ -1,4 +1,4 @@
-diffusers[torch]>=0.21.4
+diffusers[torch]>=0.27.2
 ninja
 numpy
 omegaconf==2.2.3

View File

@@ -19,4 +19,5 @@ safetensors
 prodigyopt
 torchsde
 peft==0.9.0
 unidecode
+tiktoken

View File

@@ -22,4 +22,5 @@ numpy==1.23.5
 wandb
 colorama
 safetensors
 torchsde
+tiktoken

View File

@@ -26,6 +26,7 @@ def main(args):
         except Exception as e:
             print(f"FAILED: {path}")
             failed.append((path,e))
+
     if not failed:
         print("No errors found")
     else:
@@ -33,7 +34,6 @@ def main(args):
         for path, e in failed:
             print(f"FAILED: {path} {e}")
 
-
 if __name__ == '__main__':
     print("This script checks that all images in a directory are valid.")
     print("If any errors occur, they will be printed out at the end.")

scripts/mt_grid.py Normal file
View File

@@ -0,0 +1,88 @@
from diffusers import StableDiffusionPipeline
import torch
from torch.cuda.amp import autocast
import os
import argparse
from PIL import Image, ImageDraw, ImageFont

def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
                      steps: int = 30, batch_size: int = 1):
    """
    generates a single sample at a given cfg scale and saves it to disk
    """
    with autocast():
        images = pipe(prompt,
                      num_inference_steps=steps,
                      num_images_per_prompt=batch_size,
                      guidance_scale=cfg,
                      generator=gen,
                      height=height,
                      width=width,
                      ).images
    return images

def generate_simple(prompt, model):
    pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
    images = __generate_sample(pipe, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
    return images[0]

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--epochs", type=int, default=60)
    args = argparser.parse_args()
    epochs = args.epochs

    path = "/mnt/nvme/mt/val"

    model = None
    if epochs == 100:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep100-gs159000"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep100-gs159000"
    elif epochs == 80:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep80-gs127200"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep80-gs127200"
    elif epochs == 60:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep60-gs95400"
    else:
        raise ValueError("epochs must be 100, 80, or 60")

    pipe1 = StableDiffusionPipeline.from_pretrained(model1).to("cuda")
    pipe2 = StableDiffusionPipeline.from_pretrained(model2).to("cuda")

    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".txt") and not file.endswith("file_list.txt"):
                txt_path = os.path.join(root, file)
                with open(txt_path, "r", encoding="utf-8") as f:
                    prompt = f.readline()

                generated_image1 = __generate_sample(pipe1, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
                short_prompt = prompt.split(",")[0]
                generated_image2 = __generate_sample(pipe2, short_prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
                print(short_prompt)

                gt_path = txt_path.replace(".txt", ".png")
                print(f"Loading gt_path {gt_path}")
                ground_truth_image = Image.open(gt_path)
                ground_truth_image = ground_truth_image.resize((512, 512))

                combined_image = Image.new("RGB", (1536, 576), color=(96, 96, 96))
                combined_image.paste(ground_truth_image, (0, 0))
                combined_image.paste(generated_image1, (512, 0))
                combined_image.paste(generated_image2, (1024, 0))

                draw = ImageDraw.Draw(combined_image)
                font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 18)
                draw.text((0, 510), f"epochs={epochs}", font=font)
                draw.text((200, 510), "↑ ground truth ↑", font=font)
                draw.text((650, 510), "↑ trained&generated full caption↑", font=font)
                draw.text((1140, 510), "↑ trained&generated short caption ↑", font=font)
                font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 24)
                draw.text((100, 536), prompt, font=font)
                draw.text((1240, 537), short_prompt, font=font)

                output_path = os.path.join("/mnt/nvme/mt", str(epochs), f"{file}_compare.png")
                print(f"Saving to {output_path}")
                combined_image.save(output_path)

scripts/split_val.py Normal file
View File

@@ -0,0 +1,66 @@
"""
Copyright [2022] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

## Script to move random files from source_root to destination_root for use as validation split
## chooses (num_pairs_to_move) image/caption pairs from each subfolder in source_root and moves them to destination_root, preserving subfolder structure
## also creates a validation_captions.txt file in destination_root for inference testing later to see how model performs on unseen data
## only works on (png|jpg) + txt pairs. Will break if there is no .txt file or images are other extensions

import os
import shutil
import random

source_root = "/mnt/nvme/mydata_train"
destination_root = "/mnt/nvme/mydata_val"
num_pairs_to_move = 3  # TODO: also do % based moving instead of fixed number

def move_random_file_pairs(source_folder, destination_folder):
    with open(os.path.join(destination_folder, "validation_captions.txt"), "w") as f:
        for subdir, dirs, files in os.walk(source_folder):
            for dir in dirs:
                source_subfolder = os.path.join(subdir, dir)
                destination_subfolder = os.path.join(destination_folder, dir)

                if not os.path.exists(destination_subfolder):
                    os.makedirs(destination_subfolder)

                file_list = [f for f in os.listdir(source_subfolder) if f.endswith((".png")) or f.endswith((".jpg"))]

                if len(file_list) >= num_pairs_to_move:
                    random.shuffle(file_list)

                    for i in range(num_pairs_to_move):
                        file_name = file_list[i]
                        source_file = os.path.join(source_subfolder, file_name)
                        destination_file = os.path.join(destination_subfolder, file_name)

                        caption_file_name = os.path.splitext(file_name)[0] + ".txt"
                        caption_source_file = os.path.join(source_subfolder, caption_file_name)
                        caption_destination_file = os.path.join(destination_subfolder, caption_file_name)

                        with open(caption_source_file, "r") as caption_source:
                            caption = caption_source.readline()
                            f.write(caption)
                            f.write("\n")

                        print(f"Moving {caption_source_file} to {caption_destination_file}")
                        print(f"Moving {source_file} to {destination_file}")
                        print(f"Caption: {caption}\n")
                        shutil.move(source_file, destination_file)
                        shutil.move(caption_source_file, caption_destination_file)

move_random_file_pairs(source_root, destination_root)

View File

@@ -0,0 +1,113 @@
from diffusers import StableDiffusionPipeline
import torch
from torch.cuda.amp import autocast
import os
import argparse
from PIL import Image

def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
                      steps: int = 30, batch_size: int = 1):
    """
    generates a single sample at a given cfg scale and saves it to disk
    """
    with autocast():
        images = pipe(prompt,
                      num_inference_steps=steps,
                      num_images_per_prompt=batch_size,
                      guidance_scale=cfg,
                      generator=gen,
                      height=height,
                      width=width,
                      ).images
    return images

def simple():
    pipe = StableDiffusionPipeline.from_pretrained("/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400").to("cuda")
    images = __generate_sample(pipe, "bicycle", cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
    images[0].save("test.png")

if __name__ == "__main__":
    #simple()
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--prompt_file", type=str, required=False)
    argparser.add_argument("--val_data_path", type=str, default="/mnt/nvme/mt/val", required=False)
    argparser.add_argument("--models", nargs="+", help="names of models")
    args = argparser.parse_args()
    args.val_data_path = "/mnt/nvme/mt/val/00b42"

    W = 512
    H = 512
    BATCH_SIZE = 4

    print(f"Generating grid image for {len(args.models)} models and {args.prompt_file}")
    # print each args.models
    print("Models:")
    for m in args.models:
        print(f"  {m}")

    # with open(args.prompt_file, "r") as f:
    #     prompt_master_list = []
    #     for x, line in enumerate(f):
    #         prompt_master_list.append(line.strip())

    # open the txt files in args.val_data_path
    prompt_master_list = {}
    for f in os.listdir(args.val_data_path):
        if f.endswith(".txt"):
            txt_path = os.path.join(args.val_data_path, f)
            with open(os.path.join(args.val_data_path, f), "r", encoding="utf-8") as f2:
                img_path = os.path.splitext(f)[0] + ".png"
                img_path = os.path.join(args.val_data_path, img_path)
                prompt_master_list[img_path] = f2.readline().strip()

    print(f"Found {len(prompt_master_list)} images in {args.val_data_path}")
    print(f"First 10 images: {list(prompt_master_list.values())[:10]}")
    print()

    num_lines = len(prompt_master_list)

    grid_h = (num_lines + 1) * W  # num images plus blank for left column labels
    grid_w = (1 + len(args.models)) * H  # num models plus blank for top row labels
    grid_img = Image.new("RGB", (grid_w, grid_h))

    #num_iterations = len(prompt_master_list) // BATCH_SIZE + (len(prompt_master_list) % BATCH_SIZE > 0)

    chunked_dict_list = []
    chunk = {}

    for key, value in prompt_master_list.items():
        chunk[key] = value
        if len(chunk) == BATCH_SIZE:
            chunked_dict_list.append(chunk)
            chunk = {}

    # Append any remaining items if the total number of items is not a multiple of chunk_size
    if chunk:
        chunked_dict_list.append(chunk)

    # Iterate through the chunks
    for i, chunk in enumerate(chunked_dict_list):
        print(f"Chunk {i + 1}: {chunk}")

    exit()  # NOTE: early exit left in place; the grid-rendering code below is never reached

    for i_m, model in enumerate(args.models):
        for j_p in range(len(chunked_dict_list)):  # was range(chunked_dict_list), which raises TypeError
            start_index = j_p * BATCH_SIZE
            end_index = (j_p + 1) * BATCH_SIZE
            current_prompts = list(prompt_master_list.values())[start_index:end_index]  # was a dict slice, which raises TypeError
            print(f"{model}: {current_prompts}")
            print()
            if True:
                pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
                seed_generator = torch.Generator(pipe.device).manual_seed(555)
                images = __generate_sample(pipe, current_prompts, cfg=7.5, height=512, width=512, gen=seed_generator, steps=40, batch_size=BATCH_SIZE)

                # paste each image into the grid starting from H,W and incrementing by W
                for k, k_img in enumerate(images):
                    k_img.save(f"tmp/{i_m}_{k}.png")
                    grid_img.paste(k_img, (W+k*W, H+H*i_m))

    # save the grid image
    grid_img.save(f"tmp/grid.png")

View File

@@ -8,6 +8,7 @@
     "data_root": "/mnt/q/training_samples/ff7r/man",
     "disable_amp": false,
     "disable_textenc_training": false,
+    "embedding_perturbation": 0.0,
     "flip_p": 0.0,
     "gpuid": 0,
     "gradient_checkpointing": true,
@@ -28,8 +29,7 @@
     "save_ckpts_from_n_epochs": 0,
     "save_every_n_epochs": 20,
     "save_optimizer": false,
-    "scale_lr": false,
-    "seed": 555,
+    "seed": -1,
     "timestep_start": 0,
     "timestep_end": 1000,
     "shuffle_tags": false,

View File

@@ -733,6 +733,7 @@ def main(args):
     try:
         print()
+        # currently broken on most systems?
         #unet = torch.compile(unet, mode="max-autotune")
         #text_encoder = torch.compile(text_encoder, mode="max-autotune")
         #vae = torch.compile(vae, mode="max-autotune")
@@ -833,7 +834,6 @@ def main(args):
         logging.info(
             f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")
 
-
     ed_optimizer = EveryDreamOptimizer(args,
                                        optimizer_config,
                                        text_encoder,
@@ -931,7 +931,7 @@ def main(args):
     assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
 
     # actual prediction function - shared between train and validate
-    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None):
+    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None, embedding_perturbation=0.0):
         with torch.no_grad():
             with autocast(enabled=args.amp):
                 pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@@ -968,6 +968,11 @@ def main(args):
             else:
                 encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
+            # https://arxiv.org/pdf/2405.20494
+            perturbation_deviation = embedding_perturbation / math.sqrt(encoder_hidden_states.shape[2])
+            perturbation_delta = torch.randn_like(encoder_hidden_states) * (perturbation_deviation)
+            encoder_hidden_states = encoder_hidden_states + perturbation_delta
+
             noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
             if noise_scheduler.config.prediction_type == "epsilon":
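The added perturbation scales Gaussian noise by embedding_perturbation / sqrt(hidden_dim), so the expected per-token noise norm stays roughly equal to embedding_perturbation regardless of embedding width. A self-contained sketch of the same math follows; the tensor shapes and the 0.1 strength are illustrative, not defaults from the trainer.

import math
import torch

batch, seq_len, hidden_dim = 2, 77, 768              # typical SD1.x text-encoder output shape
encoder_hidden_states = torch.randn(batch, seq_len, hidden_dim)

embedding_perturbation = 0.1                         # value of the new --embedding_perturbation arg
perturbation_deviation = embedding_perturbation / math.sqrt(hidden_dim)
perturbation_delta = torch.randn_like(encoder_hidden_states) * perturbation_deviation
perturbed = encoder_hidden_states + perturbation_delta

# dividing by sqrt(hidden_dim) keeps the expected L2 norm of the per-token noise
# close to embedding_perturbation, independent of the embedding width
print(perturbation_delta.norm(dim=-1).mean())        # ~0.1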
@@ -1195,7 +1200,8 @@ def main(args):
                 batch["tokens"],
                 args.zero_frequency_noise_ratio,
                 return_loss=True,
-                loss_scale=batch["loss_scale"])
+                loss_scale=batch["loss_scale"],
+                embedding_perturbation=args.embedding_perturbation)
 
             del target, model_pred
@@ -1362,6 +1368,7 @@ if __name__ == "__main__":
     argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
     argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
     argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
+    argparser.add_argument("--embedding_perturbation", type=float, default=0.0, help="random perturbation of text embeddings (def: 0.0)")
     argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
     argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
     argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")

View File

@@ -26,6 +26,7 @@ pip install prodigyopt
 pip install torchsde
 pip install peft>=0.9.0
 pip install unidecode
+pip install tiktoken
 
 python utils/get_yamls.py
 GOTO :eof