update caption to work with cog2 and glm-9v, add embedding_perturbation
parent d96b9cc56e
commit beec38726a

caption_cog.py | 283
@@ -47,7 +47,8 @@ except ImportError:
Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit

IMAGE_SIZE: int = 490
IMAGE_SIZE_COG1: int = 490
IMAGE_SIZE_COG2: int = 1344
PATCH_SIZE: int = 14

torch.backends.cuda.matmul.allow_tf32 = True
@@ -70,7 +71,7 @@ def save_params(args, gen_kwargs):
    with open(save_path, "w") as f:
        f.write(pretty_print)

def create_bnb_config(bnb_4bit_compute_dtype="bfloat16",bnb_4bit_quant_type= "fp4"):
def create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type= "fp4"):
    return BitsAndBytesConfig(
        bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
        bnb_4bit_quant_type=bnb_4bit_quant_type,
@@ -89,15 +90,26 @@ class BaseModelWrapper:
        self.model_name = model_name
        logging.info(f"Loading {model_name}")

    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
    def load_model(self, dtype: str="auto"):
        bnb_config = self._maybe_create_bnb_config(dtype)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config = bnb_config
        ).to(0)

        self.tokenizer = AutoProcessor.from_pretrained(self.model_name)
        return self.model, self.tokenizer

    def _maybe_create_bnb_config(self, dtype, auto_bnb=True, auto_bnb_dtype="fp4"):
        bnb_config = None
        if dtype == "auto":
            if auto_bnb:
                bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=auto_bnb_dtype)
        if dtype in ["nf4", "fp4"]:
            bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=dtype)
        return bnb_config

    def get_gen_kwargs(self, args):
        gen_kwargs = {
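In short, `_maybe_create_bnb_config` maps the `--dtype` flag onto an optional bitsandbytes config: `auto` falls back to 4-bit `fp4` (unless a wrapper opts out via `auto_bnb=False`), `nf4`/`fp4` select that quant type explicitly, and other dtypes return `None`. A minimal sketch of that mapping, assuming the standard `transformers.BitsAndBytesConfig` API and that `create_bnb_config` enables 4-bit loading:

import torch
from transformers import BitsAndBytesConfig

def resolve_quantization(dtype: str = "auto"):
    if dtype == "auto":
        dtype = "fp4"  # same fallback as auto_bnb_dtype above
    if dtype in ("nf4", "fp4"):
        return BitsAndBytesConfig(
            load_in_4bit=True,  # assumption: create_bnb_config enables 4-bit weights
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_quant_type=dtype,
        )
    return None  # fp16 / bf16: load unquantized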
@@ -129,47 +141,7 @@ class BaseModelWrapper:
        else:
            logging.debug(f"** Sampling enabled")
        return gen_kwargs

    def caption(prompt, args):
        return ""

class XtunerLlavaModelManager(BaseModelWrapper):
    def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
        self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
        super().__init__(model_name)

    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
        self.model = LlavaForConditionalGeneration.from_pretrained(
        #self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            #quantization_config=create_bnb_config()
        ).to(0)

        self.processor = LlavaProcessor.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
        print(f"self.tokenizer: {self.tokenizer}")
        # tokens = self.tokenizer("foo")
        # print(f"foo tokens test1: {tokens}")
        return self.model, self.tokenizer

    def get_inputs(self, image: Image.Image, prompt: str):
        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
        return inputs

    def _build_conversational_input_ids(self, prompt, starts_with):
        return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")

    def _truncate_to_whole_sentences(self, caption):
        # model does not stop generating cleanly and cuts off mid sentence
        caption = caption.split(".")
        caption = ". ".join(caption[0:-1]) + "."
        caption = caption.replace("\n","")
        caption = caption.replace("  "," ")
        return caption


    def _clean_caption(self, caption, args):
        """
        Removes some nonsense Llava adds.
@@ -194,11 +166,52 @@ class XtunerLlavaModelManager(BaseModelWrapper):
        caption = caption.replace(", who is the main subject of the photo.", ".")
        caption = caption.replace(", who is the main subject.", ".")
        caption = caption.replace("who is the main subject.", ".")
        caption = caption.replace(", who is the central focus of the composition.", ".")
        caption = caption.replace("who is the central focus of the composition.", ".")
        caption = self._truncate_to_whole_sentences(caption)

        logging.debug(f"**Llava post-cleaning caption: {caption}")
        return caption

    def caption(prompt, args):
        return ""

class XtunerLlavaModelManager(BaseModelWrapper):
    def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
        self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
        super().__init__(model_name)
        logging.info("Loading Xtuner Llava-Llama3 model...")

    def load_model(self, dtype="auto"):
        bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb=False)
        self.model = LlavaForConditionalGeneration.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            quantization_config=bnb_config
        ).to("cuda")

        self.processor = LlavaProcessor.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")

        return self.model, self.tokenizer

    def get_inputs(self, image: Image.Image, prompt: str):
        inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
        return inputs

    def _build_conversational_input_ids(self, prompt, starts_with):
        return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
                f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")

    def _truncate_to_whole_sentences(self, caption):
        # model does not stop generating cleanly and cuts off mid sentence
        caption = caption.split(".")
        caption = ". ".join(caption[0:-1]) + "."
        caption = caption.replace("\n","")
        caption = caption.replace("  "," ")
        return caption

    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
        gen_kwargs = self.get_gen_kwargs(args)
@@ -227,59 +240,110 @@ class XtunerLlavaModelManager(BaseModelWrapper):
        caption = self._clean_caption(caption, args)
        return caption

class MoaiManager:
# class MoaiManager:
#     def __init__(self, model_name: str):
#         self.model_name = model_name
#         self.moai_model = None
#         self.moai_processor = None
#         self.seg_model = None
#         self.seg_processor = None
#         self.od_model = None
#         self.od_processor = None
#         self.sgg_model = None
#         self.ocr_model = None
#         logging.info("Loading Moai model...")

#     def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
#         moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
#             = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
#         self.moai_model = moai_model
#         self.moai_processor = moai_processor
#         self.seg_model = seg_model
#         self.seg_processor = seg_processor
#         self.od_model = od_model
#         self.od_processor = od_processor
#         self.sgg_model = sgg_model
#         self.ocr_model = ocr_model

#         return moai_model, moai_processor

#     def get_inputs(self, image: Image.Image, prompt: str):
#         moai_inputs = self.moai_model.demo_process(image=image,
#             prompt=prompt,
#             processor=self.moai_processor,
#             seg_model=self.seg_model,
#             seg_processor=self.seg_processor,
#             od_model=self.od_model,
#             od_processor=self.od_processor,
#             sgg_model=self.sgg_model,
#             ocr_model=self.ocr_model,
#             device="cuda:0")
#         return moai_inputs

class CogGLMManager(BaseModelWrapper):
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.moai_model = None
        self.moai_processor = None
        self.seg_model = None
        self.seg_processor = None
        self.od_model = None
        self.od_processor = None
        self.sgg_model = None
        self.ocr_model = None
        super().__init__(model_name)
        if not model_name:
            self.model_name = "THUDM/cogglm-6b"
        else:
            self.model_name = model_name
        logging.info("Loading CogGLM model...")

    def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
        moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
            = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
        self.moai_model = moai_model
        self.moai_processor = moai_processor
        self.seg_model = seg_model
        self.seg_processor = seg_processor
        self.od_model = od_model
        self.od_processor = od_processor
        self.sgg_model = sgg_model
        self.ocr_model = ocr_model
    def load_model(self, dtype: str = "auto"):
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
        bnb_config = None
        if dtype in ["auto","nf4"]:
            bnb_config = create_bnb_config()
        self.model = model = AutoModelForCausalLM.from_pretrained(
            "THUDM/glm-4v-9b",
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            quantization_config=bnb_config
        ).eval()
        if bnb_config is None:
            # if BNB is used it is automatically sent to cuda device, otherwise need to move it manually
            self.model = model.to("cuda")

        return moai_model, moai_processor
    def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
        gen_kwargs = self.get_gen_kwargs(args)

    def get_inputs(self, image: Image.Image, prompt: str):
        moai_inputs = self.moai_model.demo_process(image=image,
            prompt=prompt,
            processor=self.moai_processor,
            seg_model=self.seg_model,
            seg_processor=self.seg_processor,
            od_model=self.od_model,
            od_processor=self.od_processor,
            sgg_model=self.sgg_model,
            ocr_model=self.ocr_model,
            device="cuda:0")
        return moai_inputs
        inputs = self.tokenizer.apply_chat_template([{"role": "user", "image": image, "content": prompt}],
            add_generation_prompt=True, tokenize=True, return_tensors="pt",
            return_dict=True)
        inputs.to("cuda")

        # def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
        # with torch.inference_mode():
        # generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
        # answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
        # return answer
        outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)

        len_inputs = inputs['input_ids'].shape[1]
        outputs_without_prompt = outputs[:, len_inputs:]

        caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
        return caption
class CogVLMManager(BaseModelWrapper):
    def __init__(self, model_name: str):
        super().__init__(model_name)
        self.model_name = "THUDM/cogvlm-chat-hf"
        if not model_name:
            self.model_name = "THUDM/cogvlm-chat-hf"
            self.cog_version = 1
        elif model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
            self.model_name = "THUDM/cogvlm2-llama3-chat-19b"
            self.cog_version = 2
        else:
            self.model_name = model_name
            self.cog_version = 1
        patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
        logging.info("Loading CogVLM model...")

    def load_model(self):
        self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
    def load_model(self, dtype: str = "auto"):
        if self.model_name.lower() == "THUDM/cogvlm-chat-hf".lower():
            self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
        elif self.model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
            self.tokenizer.pad_token_id = 128002 # for Llama 3
        else:
            raise ValueError("Unknown model name")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
@@ -297,7 +361,7 @@ class CogVLMManager(BaseModelWrapper):
            starts_with: Optional[str] = None,
    ):
        # based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py
        image_size: int = IMAGE_SIZE
        image_size: int = IMAGE_SIZE_COG2 if self.cog_version == 2 else IMAGE_SIZE_COG1
        patch_size: int = PATCH_SIZE
        assert images is None or len(images) <= 1, f"not support multi images by now."
        history = history or []
@@ -306,9 +370,8 @@ class CogVLMManager(BaseModelWrapper):
        text += starts_with if starts_with is not None else ""

        input_ids = [self.tokenizer.bos_token_id]
        token_type_ids = [0]
        token_type_ids = [0] # LANGUAGE_TOKEN_TYPE
        if images is not None and len(images) == 1:
            # vision
            transform = transforms.Compose(
                [
                    transforms.Resize(
@@ -319,7 +382,11 @@ class CogVLMManager(BaseModelWrapper):
                ]
            )
            images = [transform(images[0])]
            vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
            if self.cog_version == 1:
                vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
            elif self.cog_version == 2:
                vision_token_num = (image_size // patch_size // 2) * (image_size // patch_size // 2) + 2

            input_ids += [self.tokenizer.pad_token_id] * vision_token_num
            token_type_ids += [1] * vision_token_num
        text_ids = self.tokenizer.encode(text, add_special_tokens=False)
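As a sanity check on the constants above, the vision token counts work out as follows (a quick illustrative calculation, not code from the repository):

# CogVLM v1: 490 px image, 14 px patches
(490 // 14) * (490 // 14) + 2              # 35 * 35 + 2 = 1227 vision tokens
# CogVLM2: 1344 px image, 14 px patches, 2x downsampling of the patch grid
(1344 // 14 // 2) * (1344 // 14 // 2) + 2  # 48 * 48 + 2 = 2306 vision tokens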
@@ -340,10 +407,6 @@ class CogVLMManager(BaseModelWrapper):
        inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)

        # inputs['input_ids'].shape: torch.Size([1259])
        # inputs['attention_mask'].shape: torch.Size([1259])
        # inputs['images'][0].shape: torch.Size([3, 490, 490])

        inputs = {
            "input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
            "token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
@@ -352,15 +415,8 @@ class CogVLMManager(BaseModelWrapper):
            "output_hidden_states": True,
            "return_dict": True
        }
        # inputs['input_ids'].shape: torch.Size([1, 1259])
        # inputs['attention_mask'].shape: torch.Size([1, 1259])
        # inputs['images'][0][0].shape: torch.Size([3, 490, 490])
        # len(inputs['images'][0]): 1
        # len(inputs['images'][0][0]): 3

        outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
        #print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
        #print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")

        len_inputs = inputs['input_ids'].shape[1]
        outputs_without_prompt = outputs[:, len_inputs:]
@@ -369,12 +425,22 @@ class CogVLMManager(BaseModelWrapper):
        return caption

def get_model_wrapper(model_name: str):
    if "moai" in model_name:
        return MoaiManager(model_name)
    elif "llava" in model_name:
        return XtunerLlavaModelManager(model_name)
    else:
        return CogVLMManager(model_name)
    match model_name.casefold():
        # case x if "moai" in x:
        #     #return MoaiManager(model_name)
        #     return None
        case x if "llava" in x:
            return XtunerLlavaModelManager(model_name)
        case "thudm/glm-4v-9b":
            return CogGLMManager(model_name)
        case "thudm/cogvlm2-llama3-chat-19b":
            return CogVLMManager(model_name)
        case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
            return CogVLMManager(model_name)
        case None:
            return CogVLMManager(model_name)
        case _:
            raise ValueError(f"Model {model_name} not supported")

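Put together, the intended call pattern for these wrappers looks roughly like this (a hedged sketch using only names defined in this file; the real `__main__` block also handles argument parsing, image iteration, and writing caption files):

# hypothetical driver, for illustration only
wrapper = get_model_wrapper("THUDM/glm-4v-9b")  # casefolded match selects CogGLMManager
wrapper.load_model(dtype="auto")                # "auto" enables 4-bit quantization where supported
# the caption() methods above take the prompt, a PIL image, the parsed args, and word constraints:
# caption = wrapper.caption(prompt, image, args, force_words_ids=None, bad_words_ids=None)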
def get_inputs_dict(inputs):
    inputs = {
@@ -518,6 +584,7 @@ if __name__ == "__main__":
    argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
    argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
    argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
    argparser.add_argument("--dtype", choices=["auto","fp16","bf16","nf4","fp4"], default="auto", help="Data type for inference (def: auto, see docs)")
    argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
    argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
    argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling")
@@ -540,7 +607,7 @@ if __name__ == "__main__":
    argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
    argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
    argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
    argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
    argparser.add_argument("--model", type=str, default=None, help="Model to use for captioning.")
    argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped")
    args, unknown_args = argparser.parse_known_args()
@@ -8,13 +8,23 @@ It is capable of naming and identifying things with proper nouns and has a large

<a href="https://colab.research.google.com/github/nawnie/EveryDream2trainer/blob/main/CaptionCog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Both the [Vicuna-based](https://huggingface.co/THUDM/cogvlm-chat-hf) and [Llama3-based](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) models are supported.

Choose between them with one of these two CLI args:

--model THUDM/cogvlm-chat-hf

--model THUDM/cogvlm2-llama3-chat-19B

The script uses the Vicuna-based model by default if no `--model` arg is specified.
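For example, an illustrative invocation (combine it with your usual captioning arguments):

python caption_cog.py --model THUDM/cogvlm2-llama3-chat-19B

The GLM-4V model added in this update is selected the same way, with `--model THUDM/glm-4v-9b`.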
## Llava update

This script now (confusingly) also supports [Xtuner's Llava Llama3 8b v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main).

To use it, add `--model "https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main"` to your command line.

This is a work in progress. So far, `bad_words` does not seem to work.
When using Llava, the script performs some clean-up to remove less-than-useful language from the caption, because the `bad_words` part of the Hugging Face Transformers API is not supported by Llava.

## Basics
@@ -5,7 +5,7 @@ LABEL org.opencontainers.image.licenses="AGPL-3.0-only"
ARG DEBIAN_FRONTEND=noninteractive

# Don't write .pyc bytecode
ENV PYTHONDONTWRITEBYTECODE=1
# ENV PYTHONDONTWRITEBYTECODE=1

# Create workspace working directory
RUN mkdir /build
@@ -49,7 +49,7 @@ ENV DEBIAN_FRONTEND noninteractive\
ENV PYTHONUNBUFFERED=1

# Don't write .pyc bytecode
ENV PYTHONDONTWRITEBYTECODE=1
# ENV PYTHONDONTWRITEBYTECODE=1

RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
@@ -65,7 +65,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    echo "en_US.UTF-8 UTF-8" > /etc/locale.gen

# Install runpodctl
RUN wget https://github.com/runpod/runpodctl/releases/download/v1.9.0/runpodctl-linux-amd -O runpodctl && \
RUN wget https://github.com/runpod/runpodctl/releases/download/v1.14.3/runpodctl-linux-amd64 -O runpodctl && \
    chmod a+x runpodctl && \
    mv runpodctl /usr/local/bin
@@ -1,4 +1,4 @@
diffusers[torch]>=0.21.4
diffusers[torch]>=0.27.2
ninja
numpy
omegaconf==2.2.3
@@ -19,4 +19,5 @@ safetensors
prodigyopt
torchsde
peft==0.9.0
unidecode
unidecode
tiktoken
@@ -22,4 +22,5 @@ numpy==1.23.5
wandb
colorama
safetensors
torchsde
torchsde
tiktoken
@@ -26,6 +26,7 @@ def main(args):
        except Exception as e:
            print(f"FAILED: {path}")
            failed.append((path,e))

    if not failed:
        print("No errors found")
    else:
@@ -33,7 +34,6 @@ def main(args):
        for path, e in failed:
            print(f"FAILED: {path} {e}")


if __name__ == '__main__':
    print("This script checks that all images in a directory are valid.")
    print("If any errors occur, they will be printed out at the end.")
@@ -0,0 +1,88 @@
from diffusers import StableDiffusionPipeline
import torch
from torch.cuda.amp import autocast
import os

import argparse
from PIL import Image, ImageDraw, ImageFont

def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
                      steps: int = 30, batch_size: int = 1):
    """
    generates a single sample at a given cfg scale and saves it to disk
    """
    with autocast():
        images = pipe(prompt,
                      num_inference_steps=steps,
                      num_images_per_prompt=batch_size,
                      guidance_scale=cfg,
                      generator=gen,
                      height=height,
                      width=width,
        ).images

    return images

def generate_simple(prompt, model):
    pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
    images = __generate_sample(pipe, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
    return images[0]

if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--epochs", type=int, default=60)
    args = argparser.parse_args()
    epochs = args.epochs

    path = "/mnt/nvme/mt/val"

    model = None
    if epochs == 100:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep100-gs159000"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep100-gs159000"
    elif epochs == 80:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep80-gs127200"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep80-gs127200"
    elif epochs == 60:
        model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400"
        model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep60-gs95400"
    else:
        raise ValueError("epochs must be 100, 80, or 60")

    pipe1 = StableDiffusionPipeline.from_pretrained(model1).to("cuda")
    pipe2 = StableDiffusionPipeline.from_pretrained(model2).to("cuda")

    for root, dirs, files in os.walk(path):
        for file in files:
            if file.endswith(".txt") and not file.endswith("file_list.txt"):
                txt_path = os.path.join(root, file)
                with open(txt_path, "r", encoding="utf-8") as f:
                    prompt = f.readline()
                generated_image1 = __generate_sample(pipe1, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
                short_prompt = prompt.split(",")[0]
                generated_image2 = __generate_sample(pipe2, short_prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
                print(short_prompt)
                gt_path = txt_path.replace(".txt", ".png")
                print(f"Loading gt_path {gt_path}")

                ground_truth_image = Image.open(gt_path)
                ground_truth_image = ground_truth_image.resize((512, 512))

                combined_image = Image.new("RGB", (1536, 576), color=(96, 96, 96))
                combined_image.paste(ground_truth_image, (0, 0))
                combined_image.paste(generated_image1, (512, 0))
                combined_image.paste(generated_image2, (1024, 0))

                draw = ImageDraw.Draw(combined_image)
                font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 18)
                draw.text((0, 510), f"epochs={epochs}", font=font)
                draw.text((200, 510), "↑ ground truth ↑", font=font)
                draw.text((650, 510), "↑ trained&generated full caption↑", font=font)
                draw.text((1140, 510), "↑ trained&generated short caption ↑", font=font)
                font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 24)
                draw.text((100, 536), prompt, font=font)
                draw.text((1240, 537), short_prompt, font=font)

                output_path = os.path.join("/mnt/nvme/mt", str(epochs), f"{file}_compare.png")
                print(f"Saving to {output_path}")
                combined_image.save(output_path)
@@ -0,0 +1,66 @@
"""
Copyright [2022] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

## Script to move random files from source_root to destination_root for use as validation split
## chooses (num_pairs_to_move) image/caption pairs from each subfolder in source_root and moves them to destination_root, preserving subfolder structure
## also creates a validation_captions.txt file in destination_root for inference testing later to see how model performs on unseen data
## only works on (png|jpg) + txt pairs. Will break if there is no .txt file or images are other extensions

import os
import shutil
import random

source_root = "/mnt/nvme/mydata_train"
destination_root = "/mnt/nvme/mydata_val"

num_pairs_to_move = 3 # TODO: also do % based moving instead of fixed number

def move_random_file_pairs(source_folder, destination_folder):
    with open(os.path.join(destination_folder, "validation_captions.txt"), "w") as f:
        for subdir, dirs, files in os.walk(source_folder):
            for dir in dirs:
                source_subfolder = os.path.join(subdir, dir)
                destination_subfolder = os.path.join(destination_folder, dir)

                if not os.path.exists(destination_subfolder):
                    os.makedirs(destination_subfolder)

                file_list = [f for f in os.listdir(source_subfolder) if f.endswith((".png")) or f.endswith((".jpg"))]

                if len(file_list) >= num_pairs_to_move:
                    random.shuffle(file_list)

                    for i in range(num_pairs_to_move):
                        file_name = file_list[i]
                        source_file = os.path.join(source_subfolder, file_name)
                        destination_file = os.path.join(destination_subfolder, file_name)

                        caption_file_name = os.path.splitext(file_name)[0] + ".txt"
                        caption_source_file = os.path.join(source_subfolder, caption_file_name)
                        caption_destination_file = os.path.join(destination_subfolder, caption_file_name)

                        with open(caption_source_file, "r") as caption_source:
                            caption = caption_source.readline()
                            f.write(caption)
                            f.write("\n")
                            print(f"Moving {caption_source_file} to {caption_destination_file}")
                            print(f"Moving {source_file} to {destination_file}")
                            print(f"Caption: {caption}\n")
                        shutil.move(source_file, destination_file)
                        shutil.move(caption_source_file, caption_destination_file)


move_random_file_pairs(source_root, destination_root)
@@ -0,0 +1,113 @@
from diffusers import StableDiffusionPipeline
import torch
from torch.cuda.amp import autocast
import os

import argparse
from PIL import Image

def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
                      steps: int = 30, batch_size: int = 1):
    """
    generates a single sample at a given cfg scale and saves it to disk
    """
    with autocast():
        images = pipe(prompt,
                      num_inference_steps=steps,
                      num_images_per_prompt=batch_size,
                      guidance_scale=cfg,
                      generator=gen,
                      height=height,
                      width=width,
        ).images

    return images

def simple():
    pipe = StableDiffusionPipeline.from_pretrained("/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400").to("cuda")
    images = __generate_sample(pipe, "bicycle", cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
    images[0].save("test.png")

if __name__ == "__main__":
    #simple()
    argparser = argparse.ArgumentParser()
    argparser.add_argument("--prompt_file", type=str, required=False)
    argparser.add_argument("--val_data_path", type=str, default="/mnt/nvme/mt/val", required=False)
    argparser.add_argument("--models", nargs="+", help="names of models")
    args = argparser.parse_args()
    args.val_data_path = "/mnt/nvme/mt/val/00b42"

    W = 512
    H = 512
    BATCH_SIZE = 4

    print(f"Generating grid image for {len(args.models)} models and {args.prompt_file}")
    #print each args.models
    print("Models:")
    for m in args.models:
        print(f" {m}")

    # with open(args.prompt_file, "r") as f:
    #     prompt_master_list = []
    #     for x, line in enumerate(f):
    #         prompt_master_list.append(line.strip())

    # open the txt files in args.val_data_path
    prompt_master_list = {}
    for f in os.listdir(args.val_data_path):
        if f.endswith(".txt"):
            txt_path = os.path.join(args.val_data_path, f)
            with open(os.path.join(args.val_data_path, f), "r", encoding="utf-8") as f2:
                img_path = os.path.splitext(f)[0] + ".png"
                img_path = os.path.join(args.val_data_path, img_path)
                prompt_master_list[img_path] = f2.readline().strip()

    print(f"Found {len(prompt_master_list)} images in {args.val_data_path}")
    print(f"First 10 images: {list(prompt_master_list.values())[:10]}")
    print()

    num_lines = len(prompt_master_list)
    grid_h = (num_lines + 1) * W # num images plus blank for left column labels
    grid_w = (1 + len(args.models)) * H # num models plus blank for top row labels
    grid_img = Image.new("RGB", (grid_w, grid_h))

    #num_iterations = len(prompt_master_list) // BATCH_SIZE + (len(prompt_master_list) % BATCH_SIZE > 0)

    chunked_dict_list = []
    chunk = {}
    for key, value in prompt_master_list.items():
        chunk[key] = value
        if len(chunk) == BATCH_SIZE:
            chunked_dict_list.append(chunk)
            chunk = {}

    # Append any remaining items if the total number of items is not a multiple of chunk_size
    if chunk:
        chunked_dict_list.append(chunk)

    # Iterate through the chunks
    for i, chunk in enumerate(chunked_dict_list):
        print(f"Chunk {i + 1}: {chunk}")
    exit()

    for i_m, model in enumerate(args.models):
        for j_p in range(chunked_dict_list):
            start_index = j_p * BATCH_SIZE
            end_index = (j_p + 1) * BATCH_SIZE
            current_prompts = prompt_master_list[start_index:end_index]

            print(f"{model}: {current_prompts}")
            print()

            if True:
                pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
                seed_generator = torch.Generator(pipe.device).manual_seed(555)
                images = __generate_sample(pipe, current_prompts, cfg=7.5, height=512, width=512, gen=seed_generator, steps=40, batch_size=BATCH_SIZE)
                # paste each image into the grid starting from H,W and incrementing by W
                for k, k_img in enumerate(images):
                    k_img.save(f"tmp/{i_m}_{k}.png")
                    grid_img.paste(k_img, (W+k*W, H+H*i_m))
    # save the grid image
    grid_img.save(f"tmp/grid.png")
@@ -8,6 +8,7 @@
    "data_root": "/mnt/q/training_samples/ff7r/man",
    "disable_amp": false,
    "disable_textenc_training": false,
    "embedding_perturbation": 0.0,
    "flip_p": 0.0,
    "gpuid": 0,
    "gradient_checkpointing": true,
@@ -28,8 +29,7 @@
    "save_ckpts_from_n_epochs": 0,
    "save_every_n_epochs": 20,
    "save_optimizer": false,
    "scale_lr": false,
    "seed": 555,
    "seed": -1,
    "timestep_start": 0,
    "timestep_end": 1000,
    "shuffle_tags": false,

train.py | 13
@@ -733,6 +733,7 @@ def main(args):
    try:
        print()
        # currently broken on most systems?
        #unet = torch.compile(unet, mode="max-autotune")
        #text_encoder = torch.compile(text_encoder, mode="max-autotune")
        #vae = torch.compile(vae, mode="max-autotune")
@@ -833,7 +834,6 @@ def main(args):
        logging.info(
            f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")


    ed_optimizer = EveryDreamOptimizer(args,
                                       optimizer_config,
                                       text_encoder,
@@ -931,7 +931,7 @@ def main(args):
    assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"

    # actual prediction function - shared between train and validate
    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None):
    def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None, embedding_perturbation=0.0):
        with torch.no_grad():
            with autocast(enabled=args.amp):
                pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@@ -968,6 +968,11 @@ def main(args):
            else:
                encoder_hidden_states = encoder_hidden_states.last_hidden_state

            # https://arxiv.org/pdf/2405.20494
            perturbation_deviation = embedding_perturbation / math.sqrt(encoder_hidden_states.shape[2])
            perturbation_delta = torch.randn_like(encoder_hidden_states) * (perturbation_deviation)
            encoder_hidden_states = encoder_hidden_states + perturbation_delta

            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            if noise_scheduler.config.prediction_type == "epsilon":
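The perturbation added here is Gaussian noise scaled by the inverse square root of the hidden dimension, so the expected L2 norm of each token's noise is roughly the `embedding_perturbation` value from the referenced paper. A small self-contained sketch with dummy tensors, for illustration only (the real code operates on the text encoder output inside `get_model_prediction_and_target`):

import math
import torch

# stand-in for the text encoder output: (batch, sequence length, hidden dim)
encoder_hidden_states = torch.randn(2, 77, 768)
embedding_perturbation = 0.1  # value supplied via --embedding_perturbation

perturbation_deviation = embedding_perturbation / math.sqrt(encoder_hidden_states.shape[2])
perturbation_delta = torch.randn_like(encoder_hidden_states) * perturbation_deviation
perturbed = encoder_hidden_states + perturbation_delta

# each token embedding moves by roughly embedding_perturbation in L2 norm
print(perturbation_delta.norm(dim=-1).mean())  # ~0.1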
@@ -1195,7 +1200,8 @@ def main(args):
                batch["tokens"],
                args.zero_frequency_noise_ratio,
                return_loss=True,
                loss_scale=batch["loss_scale"])
                loss_scale=batch["loss_scale"],
                embedding_perturbation=args.embedding_perturbation)

            del target, model_pred
@@ -1362,6 +1368,7 @@ if __name__ == "__main__":
    argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
    argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
    argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
    argparser.add_argument("--embedding_perturbation", type=float, default=0.0, help="random perturbation of text embeddings (def: 0.0)")
    argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
    argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
    argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
@@ -26,6 +26,7 @@ pip install prodigyopt
pip install torchsde
pip install peft>=0.9.0
pip install unidecode
pip install tiktoken
python utils/get_yamls.py
GOTO :eof