Merge pull request #204 from victorchall/plugins

WIP plugins
Victor Hall 2023-06-29 18:13:33 -04:00 committed by GitHub
commit e00b7b571a
10 changed files with 252 additions and 14 deletions

caption_fl.py  Normal file (198 lines added)
View File

@@ -0,0 +1,198 @@
"""
Copyright [2022-2023] Victor C Hall
Licensed under the GNU Affero General Public License;
You may not use this code except in compliance with the License.
You may obtain a copy of the License at
https://www.gnu.org/licenses/agpl-3.0.en.html
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
from PIL import Image
import argparse
import requests
from transformers import Blip2Processor, Blip2ForConditionalGeneration, GitProcessor, GitForCausalLM, AutoModel, AutoProcessor
from huggingface_hub import hf_hub_download
from open_flamingo import create_model_and_transforms
import torch
from pynvml import *
import time
from colorama import Fore, Style
SUPPORTED_EXT = [".jpg", ".png", ".jpeg", ".bmp", ".jfif", ".webp"]
def get_gpu_memory_map():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    nvmlShutdown()
    return info.used/1024/1024
def remove_duplicates(string):
    words = string.split(', ') # Split the string into individual words
    unique_words = []

    for word in words:
        if word not in unique_words:
            unique_words.append(word)
        else:
            break # Stop appending words once a duplicate is found

    return ', '.join(unique_words)
def get_examples(example_root, image_processor):
    examples = []
    for root, dirs, files in os.walk(example_root):
        for file in files:
            ext = os.path.splitext(file)[-1].lower()
            if ext in SUPPORTED_EXT:
                #get .txt file of same base name
                txt_file = os.path.splitext(file)[0] + ".txt"
                with open(os.path.join(root, txt_file), 'r') as f:
                    caption = f.read()
                image = Image.open(os.path.join(root, file))
                vision_x = [image_processor(image).unsqueeze(0)]
                #vision_x = torch.cat(vision_x, dim=0)
                #vision_x = vision_x.unsqueeze(1).unsqueeze(0)
                examples.append((caption, vision_x))

    for x in examples:
        print(f" ** Example: {x[0]}")

    return examples
def main(args):
    device = "cuda" if torch.cuda.is_available() and not args.force_cpu else "cpu"
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    if args.prompt:
        prompt = args.prompt
    else:
        prompt = "<image>: "
    print(f" using prompt: {prompt}")

    if "mpt7b" in args.model:
        lang_encoder_path = "anas-awadalla/mpt-7b"
        tokenizer_path = "anas-awadalla/mpt-7b"
    elif "mpt1b" in args.model:
        lang_encoder_path = "anas-awadalla/mpt-1b"
        tokenizer_path = "anas-awadalla/mpt-1b"
    else:
        # only the mpt7b and mpt1b variants have encoder/tokenizer paths configured here
        raise ValueError(f"No language encoder/tokenizer path configured for model: {args.model}")

    model, image_processor, tokenizer = create_model_and_transforms(
        clip_vision_encoder_path="ViT-L-14",
        clip_vision_encoder_pretrained="openai",
        lang_encoder_path=lang_encoder_path,
        tokenizer_path=tokenizer_path,
        cross_attn_every_n_layers=1,
    )

    tokenizer.padding_side = "left"
    checkpoint_path = hf_hub_download(args.model, "checkpoint.pt")
    model.load_state_dict(torch.load(checkpoint_path), strict=False)

    print(f"GPU memory used, before loading model: {get_gpu_memory_map()} MB")
    model.to(device, dtype=dtype)
    print(f"GPU memory used, after loading model: {get_gpu_memory_map()} MB")

    # examples give few shot learning for captioning the novel image
    examples = get_examples(args.example_root, image_processor)

    prompt = ""
    output_prompt = "Output:"
    per_image_prompt = "<image> " + output_prompt

    for example in iter(examples):
        prompt += f"{per_image_prompt}{example[0]}<|endofchunk|>"
    prompt += per_image_prompt # prepare for novel example
    prompt = prompt.replace("\n", "") # in case captions had newlines
    print(f" \n** Final full prompt with example pairs: {prompt}")

    # os.walk all files in args.data_root recursively
    for root, dirs, files in os.walk(args.data_root):
        for file in files:
            #get file extension
            ext = os.path.splitext(file)[1]
            if ext.lower() in SUPPORTED_EXT:
                start_time = time.time()

                full_file_path = os.path.join(root, file)
                image = Image.open(full_file_path)

                vision_x = [vx[1][0] for vx in examples]
                vision_x.append(image_processor(image).unsqueeze(0))
                vision_x = torch.cat(vision_x, dim=0)
                vision_x = vision_x.unsqueeze(1).unsqueeze(0)
                vision_x = vision_x.to(device, dtype=dtype)

                lang_x = tokenizer(
                    [prompt], # few-shot prompt assembled from the example captions above
                    return_tensors="pt",
                )
                lang_x.to(device)

                input_ids = lang_x["input_ids"].to(device)

                with torch.cuda.amp.autocast(dtype=dtype):
                    generated_text = model.generate(
                        vision_x=vision_x,
                        lang_x=input_ids,
                        attention_mask=lang_x["attention_mask"],
                        max_new_tokens=args.max_new_tokens,
                        min_new_tokens=args.min_new_tokens,
                        num_beams=args.num_beams,
                        temperature=args.temperature,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        repetition_penalty=args.repetition_penalty,
                    )
                del vision_x
                del lang_x

                # trim and clean
                generated_text = tokenizer.decode(generated_text[0][len(input_ids[0]):], skip_special_tokens=True)
                generated_text = generated_text.split(output_prompt)[0]
                generated_text = remove_duplicates(generated_text)

                exec_time = time.time() - start_time
                print(f"* Caption: {generated_text}")
                print(f" Time for last caption: {exec_time} sec. GPU memory used: {get_gpu_memory_map()} MB")

                name = os.path.splitext(full_file_path)[0]
                if not os.path.exists(f"{name}.txt"): # skip images that already have a .txt caption
                    with open(f"{name}.txt", "w") as f:
                        f.write(generated_text)
if __name__ == "__main__":
    print(f"Available models:")
    print(f"  openflamingo/OpenFlamingo-9B-vitl-mpt7b (default)")
    print(f"  openflamingo/OpenFlamingo-3B-vitl-mpt1b")
    print(f"  openflamingo/OpenFlamingo-4B-vitl-rpj3b")

    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, default="input", help="Path to images")
    parser.add_argument("--example_root", type=str, default="examples", help="Path to 2-3 precaptioned images to guide generation")
    parser.add_argument("--model", type=str, default="openflamingo/OpenFlamingo-9B-vitl-mpt7b", help="Model name or path")
    parser.add_argument("--force_cpu", action="store_true", default=False, help="force using CPU even if GPU is available")
    parser.add_argument("--min_new_tokens", type=int, default=20, help="minimum number of tokens to generate")
    parser.add_argument("--max_new_tokens", type=int, default=50, help="maximum number of tokens to generate")
    parser.add_argument("--num_beams", type=int, default=8, help="number of beams, more is more accurate but slower")
    parser.add_argument("--prompt", type=str, default="Output: ", help="prompt to use for generation, default is 'Output: '")
    parser.add_argument("--temperature", type=float, default=1.0, help="temperature for sampling, 1.0 is default")
    parser.add_argument("--top_k", type=int, default=0, help="top_k sampling, 0 is default")
    parser.add_argument("--top_p", type=float, default=0.9, help="top_p sampling, 0.9 is default")
    parser.add_argument("--repetition_penalty", type=float, default=1.0, help="repetition penalty, 1.0 is default")
    parser.add_argument("--length_penalty", type=float, default=1.0, help="length penalty, 1.0 is default")
    args = parser.parse_args()

    print(f"** OPEN-FLAMINGO ** Captioning files in: {args.data_root}")
    print(f"** Using model: {args.model}")

    main(args)
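For reference, the few-shot prompt that main() assembles puts one "<image> Output:" marker in front of each example caption, terminates each pair with "<|endofchunk|>", and ends with a bare "<image> Output:" for the image to be captioned. A minimal sketch of that assembly follows; the captions here are made up for illustration and are not part of the commit:

example_captions = ["a red vintage car parked on a city street", "a bowl of fruit on a wooden table"]
output_prompt = "Output:"
per_image_prompt = "<image> " + output_prompt

prompt = ""
for caption in example_captions:
    prompt += f"{per_image_prompt}{caption}<|endofchunk|>"
prompt += per_image_prompt  # the model generates the novel caption after this final marker

print(prompt)
# <image> Output:a red vintage car parked on a city street<|endofchunk|><image> Output:a bowl of fruit on a wooden table<|endofchunk|><image> Output: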

Binary file not shown (new example image, 317 KiB)

View File

@@ -0,0 +1 @@
a red and blue 'Underground' sign found in London

Binary file not shown (new example image, 42 KiB)

View File

@@ -0,0 +1 @@
a white bowl filled with creamy hummus placed on a white countertop

View File

@@ -194,7 +194,7 @@ class EveryDreamOptimizer():

         return te_config, base_config

-    def create_lr_schedulers(self, args, optimizer_config):
+    def create_lr_schedulers(self, args, optimizer_config):
         unet_config = optimizer_config["base"]
         te_config = optimizer_config["text_encoder_overrides"]

@@ -276,7 +276,7 @@ class EveryDreamOptimizer():
         decouple = True # seems bad to turn off, dadapt_adam only
         momentum = 0.0 # dadapt_sgd
         no_prox = False # ????, dadapt_adan
-        growth_rate=float("inf") # dadapt
+        growth_rate=float("inf") # dadapt various, no idea what a sane default is

         if local_optimizer_config is not None:
             betas = local_optimizer_config.get("betas", betas)

@@ -336,7 +336,7 @@ class EveryDreamOptimizer():
                 eps=epsilon, #unused for lion
                 d0=d0,
                 log_every=args.log_step,
-                growth_rate=1e5,
+                growth_rate=growth_rate,
                 decouple=decouple,
             )
         elif optimizer_name == "dadapt_adan":

@@ -371,7 +371,7 @@ class EveryDreamOptimizer():
                 weight_decay=weight_decay,
                 d0=d0,
                 log_every=args.log_step,
-                growth_rate=float("inf"),
+                growth_rate=growth_rate,
             )
         else:

plugins/plugins.py  Normal file (23 lines added)
View File

@@ -0,0 +1,23 @@
import argparse
import importlib
import logging


class BasePlugin:
    def on_epoch_start(self, **kwargs):
        pass

    def on_epoch_end(self, **kwargs):
        pass


class ExampleLoggingPlugin(BasePlugin):
    def on_epoch_start(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} starting")

    def on_epoch_end(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} finished")


def load_plugin(plugin_name):
    module = importlib.import_module(plugin_name)
    plugin_class = getattr(module, plugin_name)
    if not issubclass(plugin_class, BasePlugin):
        raise TypeError(f'{plugin_name} is not a valid plugin')
    logging.info(f"Plugin {plugin_name} loaded")
    return plugin_class()
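As a rough sketch of how a user-supplied plugin fits this loader: load_plugin imports the module named on the command line and then looks up a class of the same name inside it, so a minimal plugin could be a file epoch_timer.py on the Python path containing a class epoch_timer. The module and class names below are hypothetical and not part of this commit:

# epoch_timer.py -- hypothetical example plugin, not part of this commit
import logging
import time

from plugins.plugins import BasePlugin  # the base class added in plugins/plugins.py above

class epoch_timer(BasePlugin):  # class name matches the module name so getattr(module, plugin_name) finds it
    def on_epoch_start(self, **kwargs):
        self.start_time = time.time()

    def on_epoch_end(self, **kwargs):
        logging.info(f"Epoch {kwargs['epoch']} took {time.time() - self.start_time:.1f} sec")

It would then be enabled with the --plugins flag added to train.py below, assuming the module is importable from the training working directory.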

View File

@@ -20,3 +20,5 @@ OmegaConf==2.2.3
 numpy==1.23.5
 wandb
 colorama
+safetensors
+open-flamingo==2.0.0

View File

@@ -504,10 +504,12 @@ def main(args):
         text_encoder = text_encoder.to(device, dtype=torch.float32)

     try:
-        torch.compile(unet)
-        torch.compile(text_encoder)
-        torch.compile(vae)
-        logging.info("Successfully compiled models")
+        #unet = torch.compile(unet)
+        #text_encoder = torch.compile(text_encoder)
+        #vae = torch.compile(vae)
+        torch.set_float32_matmul_precision('high')
+        torch.backends.cudnn.allow_tf32 = True
+        #logging.info("Successfully compiled models")
     except Exception as ex:
         logging.warning(f"Failed to compile model, continuing anyway, ex: {ex}")
         pass
@@ -746,10 +748,16 @@ def main(args):
     _, batch = next(enumerate(train_dataloader))
     generate_samples(global_step=0, batch=batch)

+    from plugins.plugins import load_plugin
+    plugins = [load_plugin(name) for name in (args.plugins or [])]
+
     try:
         write_batch_schedule(args, log_folder, train_batch, epoch = 0)

         for epoch in range(args.max_epochs):
+            for plugin in plugins:
+                plugin.on_epoch_start(epoch=epoch, global_step=global_step)
+
             loss_epoch = []
             epoch_start_time = time.time()
             images_per_sec_log_step = []
@@ -773,7 +781,7 @@ def main(args):
                 del target, model_pred

                 if batch["runt_size"] > 0:
-                    loss_scale = batch["runt_size"] / args.batch_size
+                    loss_scale = (batch["runt_size"] / args.batch_size)**1.5 # further discount runts by **1.5
                     loss = loss * loss_scale

                 ed_optimizer.step(loss, step, global_step)
@@ -790,14 +798,14 @@ def main(args):
                 loss_epoch.append(loss_step)

                 if (global_step + 1) % args.log_step == 0:
-                    loss_local = sum(loss_log_step) / len(loss_log_step)
+                    loss_step = sum(loss_log_step) / len(loss_log_step)
                     lr_unet = ed_optimizer.get_unet_lr()
                     lr_textenc = ed_optimizer.get_textenc_lr()
                     loss_log_step = []

                     log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
                     log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
-                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_local, global_step=global_step)
+                    log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_step, global_step=global_step)

                     sum_img = sum(images_per_sec_log_step)
                     avg = sum_img / len(images_per_sec_log_step)
@@ -806,7 +814,7 @@ def main(args):
                     log_writer.add_scalar(tag="hyperparameter/grad scale", scalar_value=ed_optimizer.get_scale(), global_step=global_step)
                     log_writer.add_scalar(tag="performance/images per second", scalar_value=avg, global_step=global_step)

-                    logs = {"loss/log_step": loss_local, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
+                    logs = {"loss/log_step": loss_step, "lr_unet": lr_unet, "lr_te": lr_textenc, "img/s": images_per_sec}
                     append_epoch_log(global_step=global_step, epoch_pbar=epoch_pbar, gpu=gpu, log_writer=log_writer, **logs)
                     torch.cuda.empty_cache()
@@ -844,9 +852,11 @@ def main(args):
             train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
             write_batch_schedule(args, log_folder, train_batch, epoch + 1)

-            loss_local = sum(loss_epoch) / len(loss_epoch)
-            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_local, global_step=global_step)
+            loss_epoch = sum(loss_epoch) / len(loss_epoch)
+            log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
+
+            for plugin in plugins:
+                plugin.on_epoch_end(epoch=epoch, global_step=global_step)

             gc.collect()
             # end of epoch
@@ -914,6 +924,7 @@ if __name__ == "__main__":
     argparser.add_argument("--no_prepend_last", action="store_true", help="Do not prepend 'last-' to the final checkpoint filename")
     argparser.add_argument("--no_save_ckpt", action="store_true", help="Save only diffusers files, no .ckpts" )
     argparser.add_argument("--optimizer_config", default="optimizer.json", help="Path to a JSON configuration file for the optimizer. Default is 'optimizer.json'")
+    argparser.add_argument('--plugins', nargs='+', help='Names of plugins to use')
    argparser.add_argument("--project_name", type=str, default="myproj", help="Project name for logs and checkpoints, ex. 'tedbennett', 'superduperV1'")
     argparser.add_argument("--resolution", type=int, default=512, help="resolution to train", choices=supported_resolutions)
     argparser.add_argument("--resume_ckpt", type=str, required=not ('resume_ckpt' in args), default="sd_v1-5_vae.ckpt", help="The checkpoint to resume from, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1 ")

View File

@@ -21,6 +21,8 @@ pip install numpy==1.23.5
 pip install lion-pytorch
 pip install compel~=1.1.3
 pip install dadaptation
+pip install safetensors
+pip install open-flamingo==2.0.0
 python utils/get_yamls.py
 GOTO :eof