update caption to work with cog2 and glm-9v, add embedding_perturbation
This commit is contained in:
parent
d96b9cc56e
commit
beec38726a
267
caption_cog.py
267
caption_cog.py
|
@ -47,7 +47,8 @@ except ImportError:
|
||||||
|
|
||||||
Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
|
Image.MAX_IMAGE_PIXELS = 715827880*4 # expand the size limit
|
||||||
|
|
||||||
IMAGE_SIZE: int = 490
|
IMAGE_SIZE_COG1: int = 490
|
||||||
|
IMAGE_SIZE_COG2: int = 1344
|
||||||
PATCH_SIZE: int = 14
|
PATCH_SIZE: int = 14
|
||||||
|
|
||||||
torch.backends.cuda.matmul.allow_tf32 = True
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
@ -89,16 +90,27 @@ class BaseModelWrapper:
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
logging.info(f"Loading {model_name}")
|
logging.info(f"Loading {model_name}")
|
||||||
|
|
||||||
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
|
def load_model(self, dtype: str="auto"):
|
||||||
|
bnb_config = self._maybe_create_bnb_config(dtype)
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
|
quantization_config = bnb_config
|
||||||
).to(0)
|
).to(0)
|
||||||
|
|
||||||
self.tokenizer = AutoProcessor.from_pretrained(self.model_name)
|
self.tokenizer = AutoProcessor.from_pretrained(self.model_name)
|
||||||
return self.model, self.tokenizer
|
return self.model, self.tokenizer
|
||||||
|
|
||||||
|
def _maybe_create_bnb_config(self, dtype, auto_bnb=True, auto_bnb_dtype="fp4"):
|
||||||
|
bnb_config = None
|
||||||
|
if dtype == "auto":
|
||||||
|
if auto_bnb:
|
||||||
|
bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=auto_bnb_dtype)
|
||||||
|
if dtype in ["nf4", "fp4"]:
|
||||||
|
bnb_config = create_bnb_config(bnb_4bit_compute_dtype="bfloat16", bnb_4bit_quant_type=dtype)
|
||||||
|
return bnb_config
|
||||||
|
|
||||||
def get_gen_kwargs(self, args):
|
def get_gen_kwargs(self, args):
|
||||||
gen_kwargs = {
|
gen_kwargs = {
|
||||||
"max_length": args.max_length,
|
"max_length": args.max_length,
|
||||||
|
@ -130,46 +142,6 @@ class BaseModelWrapper:
|
||||||
logging.debug(f"** Sampling enabled")
|
logging.debug(f"** Sampling enabled")
|
||||||
return gen_kwargs
|
return gen_kwargs
|
||||||
|
|
||||||
def caption(prompt, args):
|
|
||||||
return ""
|
|
||||||
|
|
||||||
class XtunerLlavaModelManager(BaseModelWrapper):
|
|
||||||
def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
|
|
||||||
self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
|
|
||||||
super().__init__(model_name)
|
|
||||||
|
|
||||||
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
|
|
||||||
self.model = LlavaForConditionalGeneration.from_pretrained(
|
|
||||||
#self.model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
self.model_name,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
low_cpu_mem_usage=True,
|
|
||||||
#quantization_config=create_bnb_config()
|
|
||||||
).to(0)
|
|
||||||
|
|
||||||
self.processor = LlavaProcessor.from_pretrained(self.model_name)
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
|
|
||||||
print(f"self.tokenizer: {self.tokenizer}")
|
|
||||||
# tokens = self.tokenizer("foo")
|
|
||||||
# print(f"foo tokens test1: {tokens}")
|
|
||||||
return self.model, self.tokenizer
|
|
||||||
|
|
||||||
def get_inputs(self, image: Image.Image, prompt: str):
|
|
||||||
inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
def _build_conversational_input_ids(self, prompt, starts_with):
|
|
||||||
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
|
|
||||||
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
|
|
||||||
|
|
||||||
def _truncate_to_whole_sentences(self, caption):
|
|
||||||
# model does not stop generating cleanly and cuts off mid sentence
|
|
||||||
caption = caption.split(".")
|
|
||||||
caption = ". ".join(caption[0:-1]) + "."
|
|
||||||
caption = caption.replace("\n","")
|
|
||||||
caption = caption.replace(" "," ")
|
|
||||||
return caption
|
|
||||||
|
|
||||||
def _clean_caption(self, caption, args):
|
def _clean_caption(self, caption, args):
|
||||||
"""
|
"""
|
||||||
Removes some nonsense Llava adds.
|
Removes some nonsense Llava adds.
|
||||||
|
@ -194,11 +166,52 @@ class XtunerLlavaModelManager(BaseModelWrapper):
|
||||||
caption = caption.replace(", who is the main subject of the photo.", ".")
|
caption = caption.replace(", who is the main subject of the photo.", ".")
|
||||||
caption = caption.replace(", who is the main subject.", ".")
|
caption = caption.replace(", who is the main subject.", ".")
|
||||||
caption = caption.replace("who is the main subject.", ".")
|
caption = caption.replace("who is the main subject.", ".")
|
||||||
|
caption = caption.replace(", who is the central focus of the composition.", ".")
|
||||||
|
caption = caption.replace("who is the central focus of the composition.", ".")
|
||||||
caption = self._truncate_to_whole_sentences(caption)
|
caption = self._truncate_to_whole_sentences(caption)
|
||||||
|
|
||||||
logging.debug(f"**Llava post-cleaning caption: {caption}")
|
logging.debug(f"**Llava post-cleaning caption: {caption}")
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
|
def caption(prompt, args):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
class XtunerLlavaModelManager(BaseModelWrapper):
|
||||||
|
def __init__(self, model_name: str="xtuner/llava-llama-3-8b-v1_1-transformers"):
|
||||||
|
self.model_name = "xtuner/llava-llama-3-8b-v1_1-transformers"
|
||||||
|
super().__init__(model_name)
|
||||||
|
logging.info("Loading Xtuner Llava-Llama3 model...")
|
||||||
|
|
||||||
|
def load_model(self, dtype="auto"):
|
||||||
|
bnb_config = self._maybe_create_bnb_config(dtype, auto_bnb=False)
|
||||||
|
self.model = LlavaForConditionalGeneration.from_pretrained(
|
||||||
|
self.model_name,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
quantization_config=bnb_config
|
||||||
|
).to("cuda")
|
||||||
|
|
||||||
|
self.processor = LlavaProcessor.from_pretrained(self.model_name)
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained("xtuner/llava-llama-3-8b-v1_1-transformers")
|
||||||
|
|
||||||
|
return self.model, self.tokenizer
|
||||||
|
|
||||||
|
def get_inputs(self, image: Image.Image, prompt: str):
|
||||||
|
inputs = self.processor(prompt, image, return_tensors='pt').to(0, torch.float16)
|
||||||
|
return inputs
|
||||||
|
|
||||||
|
def _build_conversational_input_ids(self, prompt, starts_with):
|
||||||
|
return (f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{prompt}<|eot_id|>"
|
||||||
|
f"<|start_header_id|>assistant<|end_header_id|>\n\n{starts_with}")
|
||||||
|
|
||||||
|
def _truncate_to_whole_sentences(self, caption):
|
||||||
|
# model does not stop generating cleanly and cuts off mid sentence
|
||||||
|
caption = caption.split(".")
|
||||||
|
caption = ". ".join(caption[0:-1]) + "."
|
||||||
|
caption = caption.replace("\n","")
|
||||||
|
caption = caption.replace(" "," ")
|
||||||
|
return caption
|
||||||
|
|
||||||
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
|
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
|
||||||
gen_kwargs = self.get_gen_kwargs(args)
|
gen_kwargs = self.get_gen_kwargs(args)
|
||||||
|
|
||||||
|
@ -227,59 +240,110 @@ class XtunerLlavaModelManager(BaseModelWrapper):
|
||||||
caption = self._clean_caption(caption, args)
|
caption = self._clean_caption(caption, args)
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
class MoaiManager:
|
# class MoaiManager:
|
||||||
|
# def __init__(self, model_name: str):
|
||||||
|
# self.model_name = model_name
|
||||||
|
# self.moai_model = None
|
||||||
|
# self.moai_processor = None
|
||||||
|
# self.seg_model = None
|
||||||
|
# self.seg_processor = None
|
||||||
|
# self.od_model = None
|
||||||
|
# self.od_processor = None
|
||||||
|
# self.sgg_model = None
|
||||||
|
# self.ocr_model = None
|
||||||
|
# logging.info("Loading Moai model...")
|
||||||
|
|
||||||
|
# def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
|
||||||
|
# moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
|
||||||
|
# = prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
|
||||||
|
# self.moai_model = moai_model
|
||||||
|
# self.moai_processor = moai_processor
|
||||||
|
# self.seg_model = seg_model
|
||||||
|
# self.seg_processor = seg_processor
|
||||||
|
# self.od_model = od_model
|
||||||
|
# self.od_processor = od_processor
|
||||||
|
# self.sgg_model = sgg_model
|
||||||
|
# self.ocr_model = ocr_model
|
||||||
|
|
||||||
|
# return moai_model, moai_processor
|
||||||
|
|
||||||
|
# def get_inputs(self, image: Image.Image, prompt: str):
|
||||||
|
# moai_inputs = self.moai_model.demo_process(image=image,
|
||||||
|
# prompt=prompt,
|
||||||
|
# processor=self.moai_processor,
|
||||||
|
# seg_model=self.seg_model,
|
||||||
|
# seg_processor=self.seg_processor,
|
||||||
|
# od_model=self.od_model,
|
||||||
|
# od_processor=self.od_processor,
|
||||||
|
# sgg_model=self.sgg_model,
|
||||||
|
# ocr_model=self.ocr_model,
|
||||||
|
# device="cuda:0")
|
||||||
|
# return moai_inputs
|
||||||
|
|
||||||
|
class CogGLMManager(BaseModelWrapper):
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
|
super().__init__(model_name)
|
||||||
|
if not model_name:
|
||||||
|
self.model_name = "THUDM/cogglm-6b"
|
||||||
|
else:
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.moai_model = None
|
logging.info("Loading CogGLM model...")
|
||||||
self.moai_processor = None
|
|
||||||
self.seg_model = None
|
|
||||||
self.seg_processor = None
|
|
||||||
self.od_model = None
|
|
||||||
self.od_processor = None
|
|
||||||
self.sgg_model = None
|
|
||||||
self.ocr_model = None
|
|
||||||
|
|
||||||
def load_model(self, bits: int=4, grad_ckpt: bool=False, lora: bool=False, dtype: str="fp16"):
|
def load_model(self, dtype: str = "auto"):
|
||||||
moai_model, moai_processor, seg_model, seg_processor, od_model, od_processor, sgg_model, ocr_model \
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
|
||||||
= prepare_moai(moai_path=self.model_name, bits=bits, grad_ckpt=grad_ckpt, lora=lora, dtype=dtype)
|
bnb_config = None
|
||||||
self.moai_model = moai_model
|
if dtype in ["auto","nf4"]:
|
||||||
self.moai_processor = moai_processor
|
bnb_config = create_bnb_config()
|
||||||
self.seg_model = seg_model
|
self.model = model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.seg_processor = seg_processor
|
"THUDM/glm-4v-9b",
|
||||||
self.od_model = od_model
|
torch_dtype=torch.bfloat16,
|
||||||
self.od_processor = od_processor
|
low_cpu_mem_usage=True,
|
||||||
self.sgg_model = sgg_model
|
trust_remote_code=True,
|
||||||
self.ocr_model = ocr_model
|
quantization_config=bnb_config
|
||||||
|
).eval()
|
||||||
|
if bnb_config is None:
|
||||||
|
# if BNB is used it is automatically sent to cuda device, otherwise need to move it manually
|
||||||
|
self.model = model.to("cuda")
|
||||||
|
|
||||||
return moai_model, moai_processor
|
def caption(self, prompt, image, args, force_words_ids, bad_words_ids, history=[]):
|
||||||
|
gen_kwargs = self.get_gen_kwargs(args)
|
||||||
|
|
||||||
def get_inputs(self, image: Image.Image, prompt: str):
|
inputs = self.tokenizer.apply_chat_template([{"role": "user", "image": image, "content": prompt}],
|
||||||
moai_inputs = self.moai_model.demo_process(image=image,
|
add_generation_prompt=True, tokenize=True, return_tensors="pt",
|
||||||
prompt=prompt,
|
return_dict=True)
|
||||||
processor=self.moai_processor,
|
inputs.to("cuda")
|
||||||
seg_model=self.seg_model,
|
|
||||||
seg_processor=self.seg_processor,
|
|
||||||
od_model=self.od_model,
|
|
||||||
od_processor=self.od_processor,
|
|
||||||
sgg_model=self.sgg_model,
|
|
||||||
ocr_model=self.ocr_model,
|
|
||||||
device="cuda:0")
|
|
||||||
return moai_inputs
|
|
||||||
|
|
||||||
# def __call__(self, moai_inputs, do_sample=True, temperature=0.9, top_p=0.95, max_new_tokens=256, use_cache=True) -> Any:
|
outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||||
# with torch.inference_mode():
|
|
||||||
# generate_ids = self.moai_model.generate(**moai_inputs, do_sample=do_sample, temperature=temperature, top_p=top_p, max_new_tokens=max_new_tokens, use_cache=use_cache)
|
len_inputs = inputs['input_ids'].shape[1]
|
||||||
# answer = self.moai_processor.batch_decode(generate_ids, skip_special_tokens=True)[0].split("[U")[0]
|
outputs_without_prompt = outputs[:, len_inputs:]
|
||||||
# return answer
|
|
||||||
|
caption = self.tokenizer.decode(outputs_without_prompt[0], skip_special_tokens=True)
|
||||||
|
return caption
|
||||||
|
|
||||||
class CogVLMManager(BaseModelWrapper):
|
class CogVLMManager(BaseModelWrapper):
|
||||||
def __init__(self, model_name: str):
|
def __init__(self, model_name: str):
|
||||||
super().__init__(model_name)
|
super().__init__(model_name)
|
||||||
|
if not model_name:
|
||||||
self.model_name = "THUDM/cogvlm-chat-hf"
|
self.model_name = "THUDM/cogvlm-chat-hf"
|
||||||
|
self.cog_version = 1
|
||||||
|
elif model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
|
||||||
|
self.model_name = "THUDM/cogvlm2-llama3-chat-19b"
|
||||||
|
self.cog_version = 2
|
||||||
|
else:
|
||||||
|
self.model_name = model_name
|
||||||
|
self.cog_version = 1
|
||||||
patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
|
patch_cog() # fixes inv_freq key error with cogvlm, quantization, and newer transformers revisions
|
||||||
|
logging.info("Loading CogVLM model...")
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self, dtype: str = "auto"):
|
||||||
|
if self.model_name.lower() == "THUDM/cogvlm-chat-hf".lower():
|
||||||
self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
self.tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
|
||||||
|
elif self.model_name.lower() == "THUDM/cogvlm2-llama3-chat-19b".lower():
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name, trust_remote_code=True)
|
||||||
|
self.tokenizer.pad_token_id = 128002 # for Llama 3
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown model name")
|
||||||
self.model = AutoModelForCausalLM.from_pretrained(
|
self.model = AutoModelForCausalLM.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
torch_dtype=torch.bfloat16,
|
torch_dtype=torch.bfloat16,
|
||||||
|
@ -297,7 +361,7 @@ class CogVLMManager(BaseModelWrapper):
|
||||||
starts_with: Optional[str] = None,
|
starts_with: Optional[str] = None,
|
||||||
):
|
):
|
||||||
# based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py
|
# based on https://huggingface.co/THUDM/cogvlm-chat-hf/blob/main/modeling_cogvlm.py
|
||||||
image_size: int = IMAGE_SIZE
|
image_size: int = IMAGE_SIZE_COG2 if self.cog_version == 2 else IMAGE_SIZE_COG1
|
||||||
patch_size: int = PATCH_SIZE
|
patch_size: int = PATCH_SIZE
|
||||||
assert images is None or len(images) <= 1, f"not support multi images by now."
|
assert images is None or len(images) <= 1, f"not support multi images by now."
|
||||||
history = history or []
|
history = history or []
|
||||||
|
@ -306,9 +370,8 @@ class CogVLMManager(BaseModelWrapper):
|
||||||
text += starts_with if starts_with is not None else ""
|
text += starts_with if starts_with is not None else ""
|
||||||
|
|
||||||
input_ids = [self.tokenizer.bos_token_id]
|
input_ids = [self.tokenizer.bos_token_id]
|
||||||
token_type_ids = [0]
|
token_type_ids = [0] # LANGUAGE_TOKEN_TYPE
|
||||||
if images is not None and len(images) == 1:
|
if images is not None and len(images) == 1:
|
||||||
# vision
|
|
||||||
transform = transforms.Compose(
|
transform = transforms.Compose(
|
||||||
[
|
[
|
||||||
transforms.Resize(
|
transforms.Resize(
|
||||||
|
@ -319,7 +382,11 @@ class CogVLMManager(BaseModelWrapper):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
images = [transform(images[0])]
|
images = [transform(images[0])]
|
||||||
|
if self.cog_version == 1:
|
||||||
vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
|
vision_token_num = (image_size // patch_size) * (image_size // patch_size) + 2
|
||||||
|
elif self.cog_version == 2:
|
||||||
|
vision_token_num = (image_size // patch_size // 2) * (image_size // patch_size // 2) + 2
|
||||||
|
|
||||||
input_ids += [self.tokenizer.pad_token_id] * vision_token_num
|
input_ids += [self.tokenizer.pad_token_id] * vision_token_num
|
||||||
token_type_ids += [1] * vision_token_num
|
token_type_ids += [1] * vision_token_num
|
||||||
text_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
text_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
@ -340,10 +407,6 @@ class CogVLMManager(BaseModelWrapper):
|
||||||
|
|
||||||
inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)
|
inputs = self._build_conversation_input_ids(query=prompt, history=history, images=[image], starts_with=args.starts_with)
|
||||||
|
|
||||||
# inputs['input_ids'].shape: torch.Size([1259])
|
|
||||||
# inputs['attention_mask'].shape: torch.Size([1259])
|
|
||||||
# inputs['images'][0].shape: torch.Size([3, 490, 490])
|
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
|
"input_ids": inputs["input_ids"].unsqueeze(0).to("cuda"),
|
||||||
"token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
|
"token_type_ids": inputs['token_type_ids'].unsqueeze(0).to("cuda"),
|
||||||
|
@ -352,15 +415,8 @@ class CogVLMManager(BaseModelWrapper):
|
||||||
"output_hidden_states": True,
|
"output_hidden_states": True,
|
||||||
"return_dict": True
|
"return_dict": True
|
||||||
}
|
}
|
||||||
# inputs['input_ids'].shape: torch.Size([1, 1259])
|
|
||||||
# inputs['attention_mask'].shape: torch.Size([1, 1259])
|
|
||||||
# inputs['images'][0][0].shape: torch.Size([3, 490, 490])
|
|
||||||
# len(inputs['images'][0]): 1
|
|
||||||
# len(inputs['images'][0][0]): 3
|
|
||||||
|
|
||||||
outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
outputs = self.model.generate(**inputs, **gen_kwargs, force_words_ids=force_words_ids, bad_words_ids=bad_words_ids)
|
||||||
#print(f"type of outputs: {type(outputs)}, outputs shape: {outputs.shape}")
|
|
||||||
#print(f"type of hidden_states: {type(hidden_states)}, outputs shape: {hidden_states.shape}")
|
|
||||||
|
|
||||||
len_inputs = inputs['input_ids'].shape[1]
|
len_inputs = inputs['input_ids'].shape[1]
|
||||||
outputs_without_prompt = outputs[:, len_inputs:]
|
outputs_without_prompt = outputs[:, len_inputs:]
|
||||||
|
@ -369,12 +425,22 @@ class CogVLMManager(BaseModelWrapper):
|
||||||
return caption
|
return caption
|
||||||
|
|
||||||
def get_model_wrapper(model_name: str):
|
def get_model_wrapper(model_name: str):
|
||||||
if "moai" in model_name:
|
match model_name.casefold():
|
||||||
return MoaiManager(model_name)
|
# case x if "moai" in x:
|
||||||
elif "llava" in model_name:
|
# #return MoaiManager(model_name)
|
||||||
|
# return None
|
||||||
|
case x if "llava" in x:
|
||||||
return XtunerLlavaModelManager(model_name)
|
return XtunerLlavaModelManager(model_name)
|
||||||
else:
|
case "thudm/glm-4v-9b":
|
||||||
|
return CogGLMManager(model_name)
|
||||||
|
case "thudm/cogvlm2-llama3-chat-19b":
|
||||||
return CogVLMManager(model_name)
|
return CogVLMManager(model_name)
|
||||||
|
case x if x in ["thudm/cogvlm-chat-hf","thudm/cogagent-chat-hf"]:
|
||||||
|
return CogVLMManager(model_name)
|
||||||
|
case None:
|
||||||
|
return CogVLMManager(model_name)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Model {model_name} not supported")
|
||||||
|
|
||||||
def get_inputs_dict(inputs):
|
def get_inputs_dict(inputs):
|
||||||
inputs = {
|
inputs = {
|
||||||
|
@ -518,6 +584,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
|
argparser.add_argument("--batch_size", type=int, default=1, help="Batch size for batch processing. Does NOT work with COG! (def: 1)")
|
||||||
argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
argparser.add_argument("--debug", action="store_true", help="Enable debug logging")
|
||||||
argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
|
argparser.add_argument("--disable_4bit", action="store_true", help="Disables 4bit inference for compatibility or experimentation. Bad for VRAM, fallback is bf16.")
|
||||||
|
argparser.add_argument("--dtype", choices=["auto","fp16","bf16","nf4","fp4"], default="auto", help="Data type for inference (def: auto, see docs)")
|
||||||
argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
|
argparser.add_argument("--temp", type=float, default=None, help="Temperature for sampling")
|
||||||
argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
|
argparser.add_argument("--num_beams", type=int, default=2, help="Number of beams for beam search, default 1 (off)")
|
||||||
argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling")
|
argparser.add_argument("--top_k", type=int, default=None, help="Top-k, filter k highest probability tokens before sampling")
|
||||||
|
@ -540,7 +607,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
|
argparser.add_argument("--starts_with", type=str, default=None, help="Force start words on the output caption.")
|
||||||
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
|
argparser.add_argument("--remove_starts_with", action="store_true", help="Removes the starts_with words from the output caption.")
|
||||||
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
|
argparser.add_argument("--append_log", action="store_true", help="Sets logging to append mode.")
|
||||||
argparser.add_argument("--model", type=str, default="THUDM/cogvlm-chat-hf", help="Model to use for captioning.")
|
argparser.add_argument("--model", type=str, default=None, help="Model to use for captioning.")
|
||||||
argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped")
|
argparser.add_argument("--min_pixels", type=int, default=1, help="Minimum total pixel size to caption, under the limit will be skipped")
|
||||||
args, unknown_args = argparser.parse_known_args()
|
args, unknown_args = argparser.parse_known_args()
|
||||||
|
|
||||||
|
|
|
@ -8,13 +8,23 @@ It is capable of naming and identifying things with proper nouns and has a large
|
||||||
|
|
||||||
<a href="https://colab.research.google.com/github/nawnie/EveryDream2trainer/blob/main/CaptionCog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
|
<a href="https://colab.research.google.com/github/nawnie/EveryDream2trainer/blob/main/CaptionCog.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
|
||||||
|
|
||||||
|
Both the [Vicuna-based](https://huggingface.co/THUDM/cogvlm-chat-hf) and [Llama3-based](https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B) CogVLM models are supported.
|
||||||
|
|
||||||
|
Choose these by using one of these two CLI args:
|
||||||
|
|
||||||
|
--model THUDM/cogvlm-chat-hf
|
||||||
|
|
||||||
|
--model THUDM/cogvlm2-llama3-chat-19B
|
||||||
|
|
||||||
|
The script uses the Vicuna model (first) by default if no `--model` arg is specified.
|
||||||
|
|
||||||
## Llava update
|
## Llava update
|
||||||
|
|
||||||
This script now (confusingly) supports [Xtuner's Llava Llama3 8b v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main).
|
This script now (confusingly) supports [Xtuner's Llava Llama3 8b v1.1](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main).
|
||||||
|
|
||||||
To use, add `--model "https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main"` to your command line.
|
To use, add `--model "https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers/tree/main"` to your command line.
|
||||||
|
|
||||||
This is a work in progress. So far it seems bad_words do not work.
|
When using Llava, the script performs some clean-up operations to remove less-than-useful language from the caption, because the bad_words part of the Hugging Face Transformers API is not supported by Llava.
|
||||||
|
|
||||||
## Basics
|
## Basics
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ LABEL org.opencontainers.image.licenses="AGPL-3.0-only"
|
||||||
ARG DEBIAN_FRONTEND=noninteractive
|
ARG DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
||||||
# Don't write .pyc bytecode
|
# Don't write .pyc bytecode
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
# ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
|
||||||
# Create workspace working directory
|
# Create workspace working directory
|
||||||
RUN mkdir /build
|
RUN mkdir /build
|
||||||
|
@ -49,7 +49,7 @@ ENV DEBIAN_FRONTEND noninteractive\
|
||||||
ENV PYTHONUNBUFFERED=1
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
# Don't write .pyc bytecode
|
# Don't write .pyc bytecode
|
||||||
ENV PYTHONDONTWRITEBYTECODE=1
|
# ENV PYTHONDONTWRITEBYTECODE=1
|
||||||
|
|
||||||
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||||
|
@ -65,7 +65,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||||
echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
|
echo "en_US.UTF-8 UTF-8" > /etc/locale.gen
|
||||||
|
|
||||||
# Install runpodctl
|
# Install runpodctl
|
||||||
RUN wget https://github.com/runpod/runpodctl/releases/download/v1.9.0/runpodctl-linux-amd -O runpodctl && \
|
RUN wget https://github.com/runpod/runpodctl/releases/download/v1.14.3/runpodctl-linux-amd64 -O runpodctl && \
|
||||||
chmod a+x runpodctl && \
|
chmod a+x runpodctl && \
|
||||||
mv runpodctl /usr/local/bin
|
mv runpodctl /usr/local/bin
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
diffusers[torch]>=0.21.4
|
diffusers[torch]>=0.27.2
|
||||||
ninja
|
ninja
|
||||||
numpy
|
numpy
|
||||||
omegaconf==2.2.3
|
omegaconf==2.2.3
|
||||||
|
|
|
@ -20,3 +20,4 @@ prodigyopt
|
||||||
torchsde
|
torchsde
|
||||||
peft==0.9.0
|
peft==0.9.0
|
||||||
unidecode
|
unidecode
|
||||||
|
tiktoken
|
|
@ -23,3 +23,4 @@ wandb
|
||||||
colorama
|
colorama
|
||||||
safetensors
|
safetensors
|
||||||
torchsde
|
torchsde
|
||||||
|
tiktoken
|
|
@ -26,6 +26,7 @@ def main(args):
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"FAILED: {path}")
|
print(f"FAILED: {path}")
|
||||||
failed.append((path,e))
|
failed.append((path,e))
|
||||||
|
|
||||||
if not failed:
|
if not failed:
|
||||||
print("No errors found")
|
print("No errors found")
|
||||||
else:
|
else:
|
||||||
|
@ -33,7 +34,6 @@ def main(args):
|
||||||
for path, e in failed:
|
for path, e in failed:
|
||||||
print(f"FAILED: {path} {e}")
|
print(f"FAILED: {path} {e}")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
print("This script checks that all images in a directory are valid.")
|
print("This script checks that all images in a directory are valid.")
|
||||||
print("If any errors occur, they will be printed out at the end.")
|
print("If any errors occur, they will be printed out at the end.")
|
||||||
|
|
|
@ -0,0 +1,88 @@
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
import torch
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
import os
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
|
def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
|
||||||
|
steps: int = 30, batch_size: int = 1):
|
||||||
|
"""
|
||||||
|
generates a single sample at a given cfg scale and saves it to disk
|
||||||
|
"""
|
||||||
|
with autocast():
|
||||||
|
images = pipe(prompt,
|
||||||
|
num_inference_steps=steps,
|
||||||
|
num_images_per_prompt=batch_size,
|
||||||
|
guidance_scale=cfg,
|
||||||
|
generator=gen,
|
||||||
|
height=height,
|
||||||
|
width=width,
|
||||||
|
).images
|
||||||
|
|
||||||
|
return images
|
||||||
|
|
||||||
|
def generate_simple(prompt, model):
|
||||||
|
pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")
|
||||||
|
images = __generate_sample(pipe, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)
|
||||||
|
return images[0]
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
argparser = argparse.ArgumentParser()
|
||||||
|
argparser.add_argument("--epochs", type=int, default=60)
|
||||||
|
args = argparser.parse_args()
|
||||||
|
epochs = args.epochs
|
||||||
|
|
||||||
|
path = "/mnt/nvme/mt/val"
|
||||||
|
|
||||||
|
model = None
|
||||||
|
if epochs == 100:
|
||||||
|
model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep100-gs159000"
|
||||||
|
model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep100-gs159000"
|
||||||
|
elif epochs == 80:
|
||||||
|
model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep80-gs127200"
|
||||||
|
model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep80-gs127200"
|
||||||
|
elif epochs == 60:
|
||||||
|
model1 = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400"
|
||||||
|
model2 = "/mnt/q/monotype/kanji_nov2023_shortcap-20231129-152030/ckpts/kanji_nov2023_shortcap-ep60-gs95400"
|
||||||
|
else:
|
||||||
|
raise ValueError("epochs must be 100, 80, or 60")
|
||||||
|
|
||||||
|
pipe1 = StableDiffusionPipeline.from_pretrained(model1).to("cuda")
|
||||||
|
pipe2 = StableDiffusionPipeline.from_pretrained(model2).to("cuda")
|
||||||
|
|
||||||
|
for root, dirs, files in os.walk(path):
|
||||||
|
for file in files:
|
||||||
|
if file.endswith(".txt") and not file.endswith("file_list.txt"):
|
||||||
|
txt_path = os.path.join(root, file)
|
||||||
|
with open(txt_path, "r", encoding="utf-8") as f:
|
||||||
|
prompt = f.readline()
|
||||||
|
generated_image1 = __generate_sample(pipe1, prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
|
||||||
|
short_prompt = prompt.split(",")[0]
|
||||||
|
generated_image2 = __generate_sample(pipe2, short_prompt, cfg=7.5, height=512, width=512, gen=None, steps=40, batch_size=1)[0]
|
||||||
|
print(short_prompt)
|
||||||
|
gt_path = txt_path.replace(".txt", ".png")
|
||||||
|
print(f"Loading gt_path {gt_path}")
|
||||||
|
|
||||||
|
ground_truth_image = Image.open(gt_path)
|
||||||
|
ground_truth_image = ground_truth_image.resize((512, 512))
|
||||||
|
|
||||||
|
combined_image = Image.new("RGB", (1536, 576), color=(96, 96, 96))
|
||||||
|
combined_image.paste(ground_truth_image, (0, 0))
|
||||||
|
combined_image.paste(generated_image1, (512, 0))
|
||||||
|
combined_image.paste(generated_image2, (1024, 0))
|
||||||
|
|
||||||
|
draw = ImageDraw.Draw(combined_image)
|
||||||
|
font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 18)
|
||||||
|
draw.text((0, 510), f"epochs={epochs}", font=font)
|
||||||
|
draw.text((200, 510), "↑ ground truth ↑", font=font)
|
||||||
|
draw.text((650, 510), "↑ trained&generated full caption↑", font=font)
|
||||||
|
draw.text((1140, 510), "↑ trained&generated short caption ↑", font=font)
|
||||||
|
font = ImageFont.truetype("/mnt/nvme/mt/NotoSansCJK-Bold.ttc", 24)
|
||||||
|
draw.text((100, 536), prompt, font=font)
|
||||||
|
draw.text((1240, 537), short_prompt, font=font)
|
||||||
|
|
||||||
|
output_path = os.path.join("/mnt/nvme/mt", str(epochs), f"{file}_compare.png")
|
||||||
|
print(f"Saving to {output_path}")
|
||||||
|
combined_image.save(output_path)
|
|
@ -0,0 +1,66 @@
|
||||||
|
"""
|
||||||
|
Copyright [2022] Victor C Hall
|
||||||
|
|
||||||
|
Licensed under the GNU Affero General Public License;
|
||||||
|
You may not use this code except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
"""
|
||||||
|
|
||||||
|
## Script to move random files from source_root to destination_root for use as validation split
|
||||||
|
## chooses (num_pairs_to_move) image/caption pairs from each subfolder in source_root and moves them to destination_root, preserving subfolder structure
|
||||||
|
## also creates a validation_captions.txt file in destination_root for inference testing later to see how model performs on unseen data
|
||||||
|
## only works on (png|jpg) + txt pairs. Will break if there is no .txt file or images are other extensions
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import random
|
||||||
|
|
||||||
|
source_root = "/mnt/nvme/mydata_train"
|
||||||
|
destination_root = "/mnt/nvme/mydata_val"
|
||||||
|
|
||||||
|
num_pairs_to_move = 3 # TODO: also do % based moving instead of fixed number
|
||||||
|
|
||||||
|
def move_random_file_pairs(source_folder, destination_folder, num_pairs=None):
    """
    Move random image/caption pairs out of every subfolder of source_folder.

    For each directory found below source_folder, num_pairs images
    (.png or .jpg) that have a matching .txt caption are chosen at random and
    moved, together with their caption files, to the same relative location
    under destination_folder. The first line of every moved caption is also
    appended to destination_folder/validation_captions.txt, one per line.

    Folders holding fewer than num_pairs captioned images are left untouched.
    Files sitting directly in source_folder itself are never moved; only
    subfolders are processed.

    :param source_folder: root of the training data to take pairs from
    :param destination_folder: root of the validation split (created if missing)
    :param num_pairs: pairs to move per subfolder; when None the module-level
        num_pairs_to_move setting is used
    """
    if num_pairs is None:
        # keeps the original two-argument call working unchanged
        num_pairs = num_pairs_to_move

    image_extensions = (".png", ".jpg")

    # the captions file lives in the destination root, so it must exist first
    os.makedirs(destination_folder, exist_ok=True)
    captions_path = os.path.join(destination_folder, "validation_captions.txt")

    with open(captions_path, "w", encoding="utf-8") as captions_out:
        for subdir, dirs, _files in os.walk(source_folder):
            for dir_name in dirs:
                source_subfolder = os.path.join(subdir, dir_name)
                # use the path relative to the source root so nested folders
                # keep their structure instead of collapsing to their basename
                relative_subfolder = os.path.relpath(source_subfolder, source_folder)
                destination_subfolder = os.path.join(destination_folder, relative_subfolder)
                os.makedirs(destination_subfolder, exist_ok=True)

                # only images that actually have a caption can form a pair
                candidates = [
                    name for name in os.listdir(source_subfolder)
                    if name.endswith(image_extensions)
                    and os.path.isfile(
                        os.path.join(source_subfolder, os.path.splitext(name)[0] + ".txt"))
                ]

                if len(candidates) < num_pairs:
                    continue

                for file_name in random.sample(candidates, num_pairs):
                    source_file = os.path.join(source_subfolder, file_name)
                    destination_file = os.path.join(destination_subfolder, file_name)

                    caption_file_name = os.path.splitext(file_name)[0] + ".txt"
                    caption_source_file = os.path.join(source_subfolder, caption_file_name)
                    caption_destination_file = os.path.join(destination_subfolder, caption_file_name)

                    with open(caption_source_file, "r", encoding="utf-8") as caption_source:
                        # drop the line ending so the captions file gets exactly
                        # one line per caption and no blank lines
                        caption = caption_source.readline().rstrip("\r\n")

                    captions_out.write(caption + "\n")

                    print(f"Moving {caption_source_file} to {caption_destination_file}")
                    print(f"Moving {source_file} to {destination_file}")
                    print(f"Caption: {caption}\n")

                    # the caption file is closed by now, so moving it is safe
                    shutil.move(source_file, destination_file)
                    shutil.move(caption_source_file, caption_destination_file)
|
||||||
|
|
||||||
|
|
||||||
|
move_random_file_pairs(source_root, destination_root)
|
|
@ -0,0 +1,113 @@
|
||||||
|
from diffusers import StableDiffusionPipeline
|
||||||
|
import torch
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
import os
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
def __generate_sample(pipe: StableDiffusionPipeline, prompt, cfg: float, height: int, width: int, gen,
                      steps: int = 30, batch_size: int = 1):
    """
    runs the pipeline on the prompt under autocast and returns the generated images

    nothing is written to disk here; the caller receives the pipeline output's
    ``images`` attribute and decides what to do with it

    :param pipe: pipeline to call
    :param prompt: prompt passed through to the pipeline as its first argument
    :param cfg: passed as guidance_scale
    :param height: passed as height
    :param width: passed as width
    :param gen: passed as generator
    :param steps: passed as num_inference_steps
    :param batch_size: passed as num_images_per_prompt
    :return: the ``images`` attribute of the pipeline output
    """
    pipe_kwargs = {
        "num_inference_steps": steps,
        "num_images_per_prompt": batch_size,
        "guidance_scale": cfg,
        "generator": gen,
        "height": height,
        "width": width,
    }

    with autocast():
        output = pipe(prompt, **pipe_kwargs)
        images = output.images

    return images
|
||||||
|
|
||||||
|
def simple():
    """
    quick manual check: loads one hard-coded checkpoint onto cuda, generates a
    single 512x512 image for the prompt "bicycle" and saves it as test.png
    """
    checkpoint_path = "/mnt/nvme/ed2old/logs/mt_kanji-20231125-185312/ckpts/mt_kanji-ep60-gs95400"
    pipe = StableDiffusionPipeline.from_pretrained(checkpoint_path).to("cuda")

    samples = __generate_sample(
        pipe,
        "bicycle",
        cfg=7.5,
        height=512,
        width=512,
        gen=None,
        steps=40,
        batch_size=1,
    )

    first_sample = samples[0]
    first_sample.save("test.png")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Builds a comparison grid: one column per model, one row per validation
    # prompt. The first row and first column of the grid are left blank so
    # labels can be added later.
    # simple()  # uncomment for a quick single-image smoke test

    argparser = argparse.ArgumentParser()
    # --prompt_file is accepted for compatibility but not read; prompts come
    # from the .txt files in --val_data_path
    argparser.add_argument("--prompt_file", type=str, required=False)
    argparser.add_argument("--val_data_path", type=str, default="/mnt/nvme/mt/val", required=False)
    # required: everything below iterates over the model list
    argparser.add_argument("--models", nargs="+", required=True, help="names of models")
    args = argparser.parse_args()

    W = 512
    H = 512
    BATCH_SIZE = 4  # number of prompts sent to the pipeline per call

    print(f"Generating grid image for {len(args.models)} models and {args.val_data_path}")
    print("Models:")
    for m in args.models:
        print(f" {m}")

    # map "<name>.png" path -> first line of "<name>.txt" for every caption
    # file in the validation folder; sorted so the row order is reproducible
    prompt_master_list = {}
    for f in sorted(os.listdir(args.val_data_path)):
        if f.endswith(".txt"):
            txt_path = os.path.join(args.val_data_path, f)
            with open(txt_path, "r", encoding="utf-8") as f2:
                img_path = os.path.splitext(f)[0] + ".png"
                img_path = os.path.join(args.val_data_path, img_path)
                prompt_master_list[img_path] = f2.readline().strip()

    print(f"Found {len(prompt_master_list)} images in {args.val_data_path}")
    print(f"First 10 images: {list(prompt_master_list.values())[:10]}")
    print()

    num_lines = len(prompt_master_list)
    grid_w = (1 + len(args.models)) * W  # num models plus blank for left column labels
    grid_h = (num_lines + 1) * H  # num prompts plus blank for top row labels
    grid_img = Image.new("RGB", (grid_w, grid_h))

    # split the prompts into batches of at most BATCH_SIZE; the last batch may
    # be shorter when the prompt count is not a multiple of BATCH_SIZE
    all_prompts = list(prompt_master_list.values())
    chunked_prompts = [all_prompts[start:start + BATCH_SIZE]
                       for start in range(0, len(all_prompts), BATCH_SIZE)]

    os.makedirs("tmp", exist_ok=True)

    for i_m, model in enumerate(args.models):
        # load each model once, not once per batch
        pipe = StableDiffusionPipeline.from_pretrained(model).to("cuda")

        for j_p, current_prompts in enumerate(chunked_prompts):
            print(f"{model}: {current_prompts}")
            print()

            # re-seed for every call so each batch is reproducible on its own
            seed_generator = torch.Generator(pipe.device).manual_seed(555)
            # batch_size is the number of images *per prompt*; the batch itself
            # is the prompt list, so one image per prompt is requested here
            images = __generate_sample(pipe, current_prompts, cfg=7.5, height=H, width=W,
                                       gen=seed_generator, steps=40, batch_size=1)

            # with one image per prompt the pipeline is expected to return the
            # images in prompt order, so images[k] belongs to current_prompts[k]
            for k, k_img in enumerate(images):
                prompt_index = j_p * BATCH_SIZE + k
                k_img.save(f"tmp/{i_m}_{prompt_index}.png")
                # column = model (offset by the label column),
                # row = prompt (offset by the label row)
                grid_img.paste(k_img, (W * (1 + i_m), H * (1 + prompt_index)))

        # release this model before the next one is loaded
        del pipe
        torch.cuda.empty_cache()

    # save the grid image
    grid_img.save("tmp/grid.png")
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@
|
||||||
"data_root": "/mnt/q/training_samples/ff7r/man",
|
"data_root": "/mnt/q/training_samples/ff7r/man",
|
||||||
"disable_amp": false,
|
"disable_amp": false,
|
||||||
"disable_textenc_training": false,
|
"disable_textenc_training": false,
|
||||||
|
"embedding_perturbation": 0.0,
|
||||||
"flip_p": 0.0,
|
"flip_p": 0.0,
|
||||||
"gpuid": 0,
|
"gpuid": 0,
|
||||||
"gradient_checkpointing": true,
|
"gradient_checkpointing": true,
|
||||||
|
@ -28,8 +29,7 @@
|
||||||
"save_ckpts_from_n_epochs": 0,
|
"save_ckpts_from_n_epochs": 0,
|
||||||
"save_every_n_epochs": 20,
|
"save_every_n_epochs": 20,
|
||||||
"save_optimizer": false,
|
"save_optimizer": false,
|
||||||
"scale_lr": false,
|
"seed": -1,
|
||||||
"seed": 555,
|
|
||||||
"timestep_start": 0,
|
"timestep_start": 0,
|
||||||
"timestep_end": 1000,
|
"timestep_end": 1000,
|
||||||
"shuffle_tags": false,
|
"shuffle_tags": false,
|
||||||
|
|
13
train.py
13
train.py
|
@ -733,6 +733,7 @@ def main(args):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print()
|
print()
|
||||||
|
# currently broken on most systems?
|
||||||
#unet = torch.compile(unet, mode="max-autotune")
|
#unet = torch.compile(unet, mode="max-autotune")
|
||||||
#text_encoder = torch.compile(text_encoder, mode="max-autotune")
|
#text_encoder = torch.compile(text_encoder, mode="max-autotune")
|
||||||
#vae = torch.compile(vae, mode="max-autotune")
|
#vae = torch.compile(vae, mode="max-autotune")
|
||||||
|
@ -833,7 +834,6 @@ def main(args):
|
||||||
logging.info(
|
logging.info(
|
||||||
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")
|
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")
|
||||||
|
|
||||||
|
|
||||||
ed_optimizer = EveryDreamOptimizer(args,
|
ed_optimizer = EveryDreamOptimizer(args,
|
||||||
optimizer_config,
|
optimizer_config,
|
||||||
text_encoder,
|
text_encoder,
|
||||||
|
@ -931,7 +931,7 @@ def main(args):
|
||||||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||||
|
|
||||||
# actual prediction function - shared between train and validate
|
# actual prediction function - shared between train and validate
|
||||||
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None):
|
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False, loss_scale=None, embedding_perturbation=0.0):
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with autocast(enabled=args.amp):
|
with autocast(enabled=args.amp):
|
||||||
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
||||||
|
@ -968,6 +968,11 @@ def main(args):
|
||||||
else:
|
else:
|
||||||
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
encoder_hidden_states = encoder_hidden_states.last_hidden_state
|
||||||
|
|
||||||
|
# https://arxiv.org/pdf/2405.20494
|
||||||
|
perturbation_deviation = embedding_perturbation / math.sqrt(encoder_hidden_states.shape[2])
|
||||||
|
perturbation_delta = torch.randn_like(encoder_hidden_states) * (perturbation_deviation)
|
||||||
|
encoder_hidden_states = encoder_hidden_states + perturbation_delta
|
||||||
|
|
||||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||||
|
|
||||||
if noise_scheduler.config.prediction_type == "epsilon":
|
if noise_scheduler.config.prediction_type == "epsilon":
|
||||||
|
@ -1195,7 +1200,8 @@ def main(args):
|
||||||
batch["tokens"],
|
batch["tokens"],
|
||||||
args.zero_frequency_noise_ratio,
|
args.zero_frequency_noise_ratio,
|
||||||
return_loss=True,
|
return_loss=True,
|
||||||
loss_scale=batch["loss_scale"])
|
loss_scale=batch["loss_scale"],
|
||||||
|
embedding_perturbation=args.embedding_perturbation)
|
||||||
|
|
||||||
del target, model_pred
|
del target, model_pred
|
||||||
|
|
||||||
|
@ -1362,6 +1368,7 @@ if __name__ == "__main__":
|
||||||
argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
|
argparser.add_argument("--disable_amp", action="store_true", default=False, help="disables automatic mixed precision (def: False)")
|
||||||
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
argparser.add_argument("--disable_textenc_training", action="store_true", default=False, help="disables training of text encoder (def: False)")
|
||||||
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
|
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
|
||||||
|
argparser.add_argument("--embedding_perturbation", type=float, default=0.0, help="random perturbation of text embeddings (def: 0.0)")
|
||||||
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
|
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
|
||||||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
|
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
|
||||||
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
||||||
|
|
|
@ -26,6 +26,7 @@ pip install prodigyopt
|
||||||
pip install torchsde
|
pip install torchsde
|
||||||
pip install peft>=0.9.0
|
pip install peft>=0.9.0
|
||||||
pip install unidecode
|
pip install unidecode
|
||||||
|
pip install tiktoken
|
||||||
python utils/get_yamls.py
|
python utils/get_yamls.py
|
||||||
GOTO :eof
|
GOTO :eof
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue