Merge branch 'victorchall:main' into main

This commit is contained in:
Alex 2023-11-17 20:52:34 +02:00 committed by GitHub
commit dcf2969640
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 75 additions and 33 deletions

View File

@ -87,3 +87,4 @@ Make sure to check out the [tools repo](https://github.com/victorchall/EveryDrea
[Contributing](doc/CONTRIBUTING.md)
[Citations and references](doc/CITATIONS.md)

40
doc/CITATIONS.md Normal file
View File

@ -0,0 +1,40 @@
The EveryDream2 trainer is built using various open source technologies and packages.
This is neither a thorough nor a deep list, but an opinionated list of the research that is most proximal to this repo and interesting.
### Stable Diffusion's Predecessors and Components
AutoencoderKL [paper](https://arxiv.org/abs/1312.6114v11)
DDPM [paper](https://arxiv.org/abs/2006.11239) - [github](https://github.com/hojonathanho/diffusion)
CLIP [paper](https://arxiv.org/pdf/2103.00020.pdf) - [github](https://github.com/OpenAI/CLIP)
OpenClip [info](https://laion.ai/blog/large-openclip/) - [github](https://github.com/mlfoundations/open_clip)
LAION 5B [paper](https://arxiv.org/abs/2210.08402) - [datasets](https://huggingface.co/laion)
### Latent Diffusion
Latent Diffusion [paper](https://arxiv.org/abs/2112.10752) - [github](https://github.com/CompVis/latent-diffusion) -- Stable Diffusion [github](https://github.com/CompVis/stable-diffusion)
SDXL [paper](https://arxiv.org/abs/2307.01952) - [github](https://github.com/Stability-AI/generative-models)
### Captioning models
Open Flamingo [paper](https://arxiv.org/abs/2308.01390) - [github](https://github.com/mlfoundations/open_flamingo)
BLIP/BLIP2 [blip paper](https://arxiv.org/abs/2201.12086) - [blip2 github (LAVIS)](https://github.com/salesforce/LAVIS) - [blip1 github](https://github.com/salesforce/BLIP)
Kosmos-2 [paper](https://arxiv.org/abs/2306.14824) - [github](https://github.com/microsoft/unilm/tree/master/kosmos-2) - [huggingface](https://huggingface.co/microsoft/kosmos-2-patch14-224)
### Optimizers
Adam [paper](https://arxiv.org/abs/1412.6980)
8-bit block-wise quantization [paper](https://arxiv.org/abs/2110.02861) - [github](https://github.com/TimDettmers/bitsandbytes)
D-Adaptation [paper](https://arxiv.org/abs/2301.07733) - [github](https://github.com/facebookresearch/dadaptation)
DoWG [paper](https://arxiv.org/abs/2305.16284)

View File

@ -9,7 +9,7 @@ aiohttp==3.8.4
tensorboard>=2.11.0
protobuf==3.20.1
pyre-extensions==0.0.29
xformers==0.0.20
xformers==0.0.22.post7
pytorch-lightning==1.6.5
OmegaConf==2.2.3
numpy==1.23.5

View File

@ -189,7 +189,7 @@ def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, s
pipeline_ema.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.safetensors"
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
@ -210,7 +210,7 @@ def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, s
pipeline.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
sd_ckpt_path = f"{os.path.basename(save_path)}.safetensors"
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
if save_optimizer_flag:
@ -223,17 +223,15 @@ def setup_local_logger(args):
configures logger with file and console logging, logs args, and returns the datestamp
"""
log_path = args.logdir
if not os.path.exists(log_path):
os.makedirs(log_path)
json_config = json.dumps(vars(args), indent=2)
os.makedirs(log_path, exist_ok=True)
datetimestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
with open(os.path.join(log_path, f"{args.project_name}-{datetimestamp}_cfg.json"), "w") as f:
f.write(f"{json_config}")
log_folder = os.path.join(log_path, f"{args.project_name}-{datetimestamp}")
os.makedirs(log_folder, exist_ok=True)
logfilename = os.path.join(log_folder, f"{args.project_name}-{datetimestamp}.log")
logfilename = os.path.join(log_path, f"{args.project_name}-{datetimestamp}.log")
print(f" logging to {logfilename}")
logging.basicConfig(filename=logfilename,
level=logging.INFO,
@ -247,7 +245,7 @@ def setup_local_logger(args):
warnings.filterwarnings("ignore", message="UserWarning: Palette images with Transparency expressed in bytes should be converted to RGBA images")
#from PIL import Image
return datetimestamp
return datetimestamp, log_folder
# def save_optimizer(optimizer: torch.optim.Optimizer, path: str):
# """
@ -473,15 +471,14 @@ def resolve_image_train_items(args: argparse.Namespace) -> list[ImageTrainItem]:
return image_train_items
def write_batch_schedule(args: argparse.Namespace, log_folder: str, train_batch: EveryDreamBatch, epoch: int):
    """
    Write the batch schedule for one epoch to a text file.

    Does nothing unless ``args.write_schedule`` is truthy. Otherwise creates
    (or overwrites) ``{log_folder}/ep{epoch}_batch_schedule.txt`` and writes
    one line per entry of ``train_batch.image_train_items`` containing the
    step index (item position divided by ``train_batch.batch_size``), the
    item's ``target_wh``, ``runt_size`` and ``pathname``.

    Writing is best-effort per item: an item that cannot be formatted or
    written is reported through ``logging.error`` and skipped.

    :param args: parsed arguments; only ``write_schedule`` is read here
    :param log_folder: directory that receives the schedule file
    :param train_batch: object exposing ``image_train_items`` and ``batch_size``
    :param epoch: epoch number used in the file name
    """
    if not args.write_schedule:
        return

    with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
        for i, item in enumerate(train_batch.image_train_items):
            try:
                f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
            except Exception as e:
                # The item itself may be the thing that is broken, so do not
                # assume it has a pathname when reporting the failure.
                pathname = getattr(item, "pathname", "<unknown>")
                logging.error(f" * Error writing to batch schedule for file path: {pathname}: {e}")
def write_batch_schedule(log_folder: str, train_batch: EveryDreamBatch, epoch: int):
    """
    Write the batch schedule for one epoch to a text file.

    Creates (or overwrites) ``{log_folder}/ep{epoch}_batch_schedule.txt`` and
    writes one line per entry of ``train_batch.image_train_items`` containing
    the step index (item position divided by ``train_batch.batch_size``), the
    item's ``target_wh``, ``runt_size`` and ``pathname``.

    Writing is best-effort per item: an item that cannot be formatted or
    written is reported through ``logging.error`` and skipped.

    :param log_folder: directory that receives the schedule file
    :param train_batch: object exposing ``image_train_items`` and ``batch_size``
    :param epoch: epoch number used in the file name
    """
    with open(f"{log_folder}/ep{epoch}_batch_schedule.txt", "w", encoding='utf-8') as f:
        for i, item in enumerate(train_batch.image_train_items):
            try:
                f.write(f"step:{int(i / train_batch.batch_size):05}, wh:{item.target_wh}, r:{item.runt_size}, path:{item.pathname}\n")
            except Exception as e:
                # The item itself may be the thing that is broken, so do not
                # assume it has a pathname when reporting the failure.
                pathname = getattr(item, "pathname", "<unknown>")
                logging.error(f" * Error writing to batch schedule for file path: {pathname}: {e}")
def read_sample_prompts(sample_prompts_file_path: str):
@ -491,12 +488,22 @@ def read_sample_prompts(sample_prompts_file_path: str):
sample_prompts.append(line.strip())
return sample_prompts
def log_args(log_writer, args):
def log_args(log_writer, args, optimizer_config, log_folder, log_time):
    """
    Record the run configuration in the log writer and as JSON files.

    Sends a single ``"config"`` text entry to ``log_writer`` listing every
    attribute of ``args`` as ``name=value, `` in sorted order, then writes two
    files into ``log_folder``:

    * ``{project_name}-{log_time}_main.json`` - ``vars(args)`` as JSON
    * ``{project_name}-{log_time}_opt.json`` - ``optimizer_config`` as JSON

    :param log_writer: object with an ``add_text(tag, text)`` method
        (presumably a TensorBoard SummaryWriter - confirm against caller)
    :param args: namespace of run arguments; must have ``project_name``
    :param optimizer_config: JSON-serializable optimizer configuration
    :param log_folder: existing directory that receives the JSON files
    :param log_time: timestamp string used in the file names
    """
    # Build the text in one pass; every entry keeps its trailing ", ".
    arglog = "args:\n" + "".join(f"{arg}={value}, " for arg, value in sorted(vars(args).items()))
    log_writer.add_text("config", arglog)

    args_as_json = json.dumps(vars(args), indent=2)
    with open(os.path.join(log_folder, f"{args.project_name}-{log_time}_main.json"), "w", encoding="utf-8") as f:
        f.write(args_as_json)

    optimizer_config_as_json = json.dumps(optimizer_config, indent=2)
    with open(os.path.join(log_folder, f"{args.project_name}-{log_time}_opt.json"), "w", encoding="utf-8") as f:
        f.write(optimizer_config_as_json)
def update_ema(model, ema_model, decay, default_device, ema_device):
with torch.no_grad():
original_model_on_proper_device = model
@ -574,7 +581,7 @@ def main(args):
print(" * Windows detected, disabling Triton")
os.environ['XFORMERS_FORCE_DISABLE_TRITON'] = "1"
log_time = setup_local_logger(args)
log_time, log_folder = setup_local_logger(args)
args = setup_args(args)
print(f" Args:")
pprint.pprint(vars(args))
@ -593,8 +600,7 @@ def main(args):
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
#log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@ -717,8 +723,6 @@ def main(args):
text_encoder = text_encoder.to(device, dtype=torch.float32)
if use_ema_dacay_training:
if not ema_model_loaded_from_file:
logging.info(f"EMA decay enabled, creating EMA model.")
@ -835,7 +839,7 @@ def main(args):
epoch_len,
log_writer)
log_args(log_writer, args)
log_args(log_writer, args, optimizer_config, log_folder, log_time)
sample_generator = SampleGenerator(log_folder=log_folder, log_writer=log_writer,
default_resolution=args.resolution, default_seed=args.seed,
@ -869,7 +873,6 @@ def main(args):
if not interrupted:
interrupted=True
global global_step
#TODO: save model on ctrl-c
interrupted_checkpoint_path = os.path.join(f"{log_folder}/ckpts/interrupted-gs{global_step}")
print()
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
@ -1121,12 +1124,11 @@ def main(args):
text_encoder_ema=text_encoder_ema)
epoch = None
try:
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
try:
plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
for epoch in range(args.max_epochs):
write_batch_schedule(log_folder, train_batch, epoch) if args.write_schedule else None
if args.load_settings_every_epoch:
load_train_json_from_file(args)
@ -1287,7 +1289,6 @@ def main(args):
epoch_pbar.update(1)
if epoch < args.max_epochs - 1:
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
if len(loss_epoch) > 0:
loss_epoch = sum(loss_epoch) / len(loss_epoch)