commit e00b7b571a

@@ -0,0 +1,198 @@
"""
Copyright [2022-2023] Victor C Hall

Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at

    https://www.gnu.org/licenses/agpl-3.0.en.html

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os

from PIL import Image
import argparse
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms

import torch
from pynvml import *

import time
from colorama import Fore, Style


SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]

def get_gpu_memory_map():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    nvmlShutdown()
    return info.used/1024/1024

def remove_duplicates(string):
    words = string.split(', ') # Split the string into individual words
    unique_words = []

    for word in words:
        if word not in unique_words:
            unique_words.append(word)
        else:
            break # Stop appending words once a duplicate is found

    return ', '.join(unique_words)

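# few-shot examples: every image under example_root must have a .txt caption with the same
# base filename; the captions are stitched into the prompt and the images are stacked into
# vision_x in main() below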
def get_examples(example_root, image_processor):
    examples = []
    for root, dirs, files in os.walk(example_root):
        for file in files:
            ext = os.path.splitext(file)[-1].lower()
            if ext in SUPPORTED_EXT:
                #get .txt file of same base name
                txt_file = os.path.splitext(file)[0] + ".txt"
                with open(os.path.join(root, txt_file), 'r') as f:
                    caption = f.read()
                image = Image.open(os.path.join(root, file))
                vision_x = [image_processor(image).unsqueeze(0)]
                #vision_x = torch.cat(vision_x, dim=0)
                #vision_x = vision_x.unsqueeze(1).unsqueeze(0)
                examples.append((caption, vision_x))
    for x in examples:
        print(f" ** Example: {x[0]}")
    return examples

def main(args):
    device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    if args.prompt:
        prompt = args.prompt
    else:
        prompt = "<image>: "
    print(f" using prompt: {prompt}")

    if "mpt7b" in args.model:
        lang_encoder_path="anas-awadalla/mpt-7b"
        tokenizer_path="anas-awadalla/mpt-7b"
    elif "mpt1b" in args.model:
        lang_encoder_path="anas-awadalla/mpt-1b"
        tokenizer_path="anas-awadalla/mpt-1b"

    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path=lang_encoder_path,
        tokenizer_path=tokenizer_path,
        cross_attn_every_n_layers=1,
    )

    tokenizer.padding_side = "left"

    checkpoint_path = hf_hub_download(args.model, "checkpoint.pt")
    model.load_state_dict(torch.load(checkpoint_path), strict=False)
    print(f"GPU memory used, before loading model: {get_gpu_memory_map()} MB")
    model.to(0, dtype=dtype)
    print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")

    # examples give few shot learning for captioning the novel image
    examples = get_examples(args.example_root, image_processor)

    prompt = ""
    output_prompt = "Output:"
    per_image_prompt = "<image> " + output_prompt

    for example in iter(examples):
        prompt += f"{per_image_prompt}{example[0]}<|endofchunk|>"
    prompt += per_image_prompt # prepare for novel example
    prompt = prompt.replace("\n", "") # in case captions had newlines
    print(f" \n** Final full prompt with example pairs: {prompt}")

    # os.walk all files in args.data_root recursively
    for root, dirs, files in os.walk(args.data_root):
        for file in files:
            #get file extension
            ext = os.path.splitext(file)[1]
            if ext.lower() in SUPPORTED_EXT:
                start_time = time.time()

                full_file_path = os.path.join(root, file)
                image = Image.open(full_file_path)

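                # stack the example images plus the new image into a single tensor; OpenFlamingo's
                # generate() expects vision_x shaped (batch, num_images, frames, channels, height, width),
                # so the unsqueeze(1).unsqueeze(0) calls below add the frame and batch dimensions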
                vision_x = [vx[1][0] for vx in examples]
                vision_x.append(image_processor(image).unsqueeze(0))
                vision_x = torch.cat(vision_x, dim=0)
                vision_x = vision_x.unsqueeze(1).unsqueeze(0)
                vision_x = vision_x.to(device, dtype=dtype)

                lang_x = tokenizer(
                    [prompt], # blank for image captioning
                    return_tensors="pt",
                )
                lang_x.to(device)

                input_ids = lang_x["input_ids"].to(device)

                with torch.cuda.amp.autocast(dtype=dtype):
                    generated_text = model.generate(
                        vision_x=vision_x,
                        lang_x=input_ids,
                        attention_mask=lang_x["attention_mask"],
                        max_new_tokens=args.max_new_tokens,
                        min_new_tokens=args.min_new_tokens,
                        num_beams=args.num_beams,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=args.repetition_penalty,
                    )
                del vision_x
                del lang_x

                # trim and clean
                generated_text = tokenizer.decode(generated_text[0][len(input_ids[0]):], skip_special_tokens=True)
                generated_text = generated_text.split(output_prompt)[0]
                generated_text = remove_duplicates(generated_text)

                exec_time = time.time() - start_time
                print(f"* Caption: {generated_text}")

                print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB")

                name = os.path.splitext(full_file_path)[0]
                if not os.path.exists(name):
                    with open(f"{name}.txt", "w") as f:
                        f.write(generated_text)

if __name__ == "__main__":
    print(f"Available models:")
    print(f"  openflamingo/OpenFlamingo-9B-vitl-mpt7b (default)")
    print(f"  openflamingo/OpenFlamingo-3B-vitl-mpt1b")
    print(f"  openflamingo/OpenFlamingo-4B-vitl-rpj3b")
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="input", help="Path to images")
    parser.add_argument("--example_root", type=str, default="examples", help="Path to 2-3 precaptioned images to guide generation")
    parser.add_argument("--model", type=str, default="openflamingo/OpenFlamingo-9B-vitl-mpt7b", help="Model name or path")
    parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available")
    parser.add_argument("--min_new_tokens", type=int, default=20, help="minimum number of tokens to generate")
    parser.add_argument("--max_new_tokens", type=int, default=50, help="maximum number of tokens to generate")
    parser.add_argument("--num_beams", type=int, default=8, help="number of beams, more is more accurate but slower")
    parser.add_argument("--prompt", type=str, default="Output: ", help="prompt to use for generation, default is 'Output: '")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature for sampling, 1.0 is default")
    parser.add_argument("--top_k", type=int, default=0, help="top_k sampling, 0 is default")
    parser.add_argument("--top_p", type=float, default=0.9, help="top_p sampling, 0.9 is default")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty, 1.0 is default")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="length penalty, 1.0 is default")
    args = parser.parse_args()

    print(f"** OPEN-FLAMINGO ** Captioning files in: {args.data_root}")
    print(f"** Using model: {args.model}")
    main(args)

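For reference, a minimal sketch of the interleaved few-shot prompt the loop above assembles; the two captions are the example files added later in this commit, and the variable names here are only illustrative:

    captions = [
        "a red and blue 'Underground' sign found in London",
        "a white bowl filled with creamy hummus placed on a white countertop",
    ]
    prompt = ""
    for caption in captions:
        prompt += f"<image> Output:{caption}<|endofchunk|>"
    prompt += "<image> Output:"  # open slot for the novel image to be captioned
    # result (wrapped here for readability; the real string has no newlines):
    # "<image> Output:a red and blue 'Underground' sign found in London<|endofchunk|>
    #  <image> Output:a white bowl filled with creamy hummus placed on a white countertop<|endofchunk|>
    #  <image> Output:"
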
Binary file not shown (new example image, 317 KiB)

@@ -0,0 +1 @@
a red and blue 'Underground' sign found in London

BIN  examples/a white bowl filled with creamy hummus placed on a white countertop.jpg (executable file; binary not shown, 42 KiB)

@@ -0,0 +1 @@
a white bowl filled with creamy hummus placed on a white countertop

@@ -194,7 +194,7 @@ class EveryDreamOptimizer():
 
         return te_config, base_config
 
-    def create_lr_schedulers(self, args, optimizer_config):
+    def create_lr_schedulers(self, args, optimizer_config):
         unet_config = optimizer_config["base"]
         te_config = optimizer_config["text_encoder_overrides"]
 

@@ -276,7 +276,7 @@ class EveryDreamOptimizer():
         decouple = True # seems bad to turn off, dadapt_adam only
         momentum = 0.0 # dadapt_sgd
         no_prox = False # ????, dadapt_adan
-        growth_rate=float("inf") # dadapt
+        growth_rate=float("inf") # dadapt various, no idea what a sane default is
 
         if local_optimizer_config is not None:
             betas = local_optimizer_config.get("betas", betas)

@@ -336,7 +336,7 @@ class EveryDreamOptimizer():
                 eps=epsilon, #unused for lion
                 d0=d0,
                 log_every=args.log_step,
-                growth_rate=1e5,
+                growth_rate=growth_rate,
                 decouple=decouple,
             )
         elif optimizer_name == "dadapt_adan":

@@ -371,7 +371,7 @@ class EveryDreamOptimizer():
                 weight_decay=weight_decay,
                 d0=d0,
                 log_every=args.log_step,
-                growth_rate=float("inf"),
+                growth_rate=growth_rate,
             )
 
         else:

@@ -0,0 +1,23 @@
import argparse
import importlib
import logging

class BasePlugin:
    def on_epoch_start(self, **kwargs):
        pass
    def on_epoch_end(self, **kwargs):
        pass

class ExampleLoggingPlugin(BasePlugin):
    def on_epoch_start(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} starting")
    def on_epoch_end(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} finished")

def load_plugin(plugin_name):
    module = importlib.import_module(plugin_name)
    plugin_class = getattr(module, plugin_name)
    if not issubclass(plugin_class, BasePlugin):
        raise TypeError(f'{plugin_name} is not a valid plugin')
    logging.info(f"Plugin {plugin_name} loaded")
    return plugin_class()

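For context, a minimal sketch of a user plugin that the load_plugin helper above could import. The loader resolves a plugin by importing a module with the given name and then looking up a class of the same name inside it, so the module/class name below is hypothetical and the module must be importable from the working directory or PYTHONPATH:

    # my_logging_plugin.py  (hypothetical module; the class name must match the module name)
    import logging
    from plugins.base_plugin import BasePlugin

    class my_logging_plugin(BasePlugin):
        def on_epoch_start(self, epoch, global_step):
            logging.info(f"epoch {epoch} starting at global step {global_step}")

        def on_epoch_end(self, epoch, global_step):
            logging.info(f"epoch {epoch} finished at global step {global_step}")

It would then be enabled via the new train.py argument, e.g. --plugins my_logging_plugin.
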
@@ -20,3 +20,5 @@ OmegaConf==2.2.3
 numpy==1.23.5
 wandb
 colorama
+safetensors
+open-flamingo==2.0.0

31  train.py

@@ -504,10 +504,12 @@ def main(args):
     text_encoder = text_encoder.to(device, dtype=torch.float32)
 
     try:
-        torch.compile(unet)
-        torch.compile(text_encoder)
-        torch.compile(vae)
-        logging.info("Successfully compiled models")
+        #unet = torch.compile(unet)
+        #text_encoder = torch.compile(text_encoder)
+        #vae = torch.compile(vae)
+        torch.set_float32_matmul_precision('high')
+        torch.backends.cudnn.allow_tf32 = True
+        #logging.info("Successfully compiled models")
     except Exception as ex:
         logging.warning(f"Failed to compile model, continuing anyway, ex: {ex}")
         pass

@@ -746,10 +748,16 @@ def main(args):
     _, batch = next(enumerate(train_dataloader))
     generate_samples(global_step=0, batch=batch)
 
+    from plugins.base_plugin import load_plugin
+    plugins = [load_plugin(name) for name in args.plugins]
+
     try:
         write_batch_schedule(args, log_folder, train_batch, epoch = 0)
 
         for epoch in range(args.max_epochs):
+            for plugin in plugins:
+                plugin.on_epoch_start(epoch, global_step)
+
             loss_epoch = []
             epoch_start_time = time.time()
             images_per_sec_log_step = []

@@ -773,7 +781,7 @@ def main(args):
                 del target, model_pred
 
                 if batch["runt_size"] > 0:
-                    loss_scale = batch["runt_size"] / args.batch_size
+                    loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5
                     loss = loss * loss_scale
 
                 ed_optimizer.step(loss, step, global_step)

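For example, with batch_size 8 and a runt batch of 4 images, the previous code scaled the loss by 4/8 = 0.5, while the new code scales it by (4/8)**1.5 ≈ 0.354, discounting undersized leftover batches more aggressively.
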
@@ -790,14 +798,14 @@ def main(args):
                 loss_epoch.append(loss_step)
 
                 if (global_step + 1) % args.log_step == 0:
-                    loss_local = sum(loss_log_step) / len(loss_log_step)
+                    loss_step = sum(loss_log_step) / len(loss_log_step)
                     lr_unet = ed_optimizer.get_unet_lr()
                     lr_textenc = ed_optimizer.get_textenc_lr()
                     loss_log_step = []
 
                     log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
                     log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
-                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
+                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_step, global_step=global_step)
 
                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)

@@ -806,7 +814,7 @@ def main(args):
                     log_writer.add_scalar(tag="hyperparameter/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
                     log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)
 
-                    logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
+                    logs = {"loss/log_step": loss_step, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
                     append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                     torch.cuda.empty_cache()
 

@@ -844,9 +852,11 @@ def main(args):
             train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
             write_batch_schedule(args, log_folder, train_batch, epoch + 1)
 
-            loss_local = sum(loss_epoch) / len(loss_epoch)
-            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
+            loss_epoch = sum(loss_epoch) / len(loss_epoch)
+            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
 
+            for plugin in plugins:
+                plugin.on_epoch_end(epoch, global_step)
             gc.collect()
             # end of epoch
 

@@ -914,6 +924,7 @@ if __name__ == "__main__":
     argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
     argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
     argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
+    argparser.add_argument('--plugins', nargs='+', help='Names of plugins to use')
     argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
     argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
     argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")

@@ -21,6 +21,8 @@ pip install numpy==1.23.5
 pip install lion-pytorch
 pip install compel~=1.1.3
 pip install dadaptation
+pip install safetensors
+pip install open-flamingo==2.0.0
 python utils/get_yamls.py
 GOTO :eof
 