Merge branch 'victorchall:main' into feat_add_sde_samplers

This commit is contained in:
Damian Stewart 2023-10-23 00:04:03 +02:00 committed by GitHub
commit bc1058a0d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 860 additions and 145 deletions

View File

@ -124,7 +124,8 @@
" 'wandb',\n",
" 'colorama',\n",
" 'keyboard',\n",
" 'lion-pytorch'\n",
" 'lion-pytorch',\n",
" 'safetensors'\n",
"]\n",
"\n",
"print(colored(0, 255, 0, 'Installing packages...'))\n",
@ -543,7 +544,7 @@
" --resolution $Resolution \\\n",
" --sample_prompts \"$Sample_File\" \\\n",
" --sample_steps $Steps_between_samples \\\n",
" --save_every_n_epoch $Save_every_N_epoch \\\n",
" --save_every_n_epochs $Save_every_N_epoch \\\n",
" --seed $Training_Seed \\\n",
" --zero_frequency_noise_ratio $zero_frequency_noise\n",
"\n",

View File

@ -41,6 +41,7 @@ class EveryDreamBatch(Dataset):
seed=555,
tokenizer=None,
shuffle_tags=False,
keep_tags=0,
rated_dataset=False,
rated_dataset_dropout_target=0.5,
name='train'
@ -54,6 +55,7 @@ class EveryDreamBatch(Dataset):
self.tokenizer = tokenizer
self.max_token_length = self.tokenizer.model_max_length
self.shuffle_tags = shuffle_tags
self.keep_tags = keep_tags
self.seed = seed
self.rated_dataset = rated_dataset
self.rated_dataset_dropout_target = rated_dataset_dropout_target
@ -94,7 +96,7 @@ class EveryDreamBatch(Dataset):
)
if self.shuffle_tags or train_item["shuffle_tags"]:
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
else:
example["caption"] = train_item["caption"].get_caption()
@ -138,7 +140,59 @@ class EveryDreamBatch(Dataset):
def __update_image_train_items(self, dropout_fraction: float):
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
class DataLoaderWithFixedBuffer(torch.utils.data.DataLoader):
def __init__(self, dataset, batch_size: int, max_pixels: int, buffer_dtype: torch.dtype = torch.float32, device="cuda"): # float32 matches the .float() cast in fixed_collate_fn
color_channels = 3
buffer_size = batch_size * color_channels * max_pixels
self.buffer_size = buffer_size
buffer_tensor = torch.empty(buffer_size, dtype=buffer_dtype, device="cpu").pin_memory() # pinned buffers must be CPU tensors
self.buffer_tensor = buffer_tensor
logging.info(f"buffer_tensor created with shape: {buffer_tensor.shape}")
super().__init__(dataset, batch_size=batch_size, shuffle=False, num_workers=min(batch_size, os.cpu_count()), collate_fn=self.fixed_collate_fn)
def fixed_collate_fn(self, batch):
"""
Collates images to a pinned buffer returned as a view using actual resolution shape
"""
images = [example["image"] for example in batch]
# map the image data to the fixed buffer view
c, h, w = images[0].shape
item_numel = images[0].numel()
for i, image in enumerate(images):
self.buffer_tensor[i*item_numel:(i+1)*item_numel] = image.reshape(-1)
images = self.buffer_tensor[:len(batch)*item_numel].view(len(batch), c, h, w)
captions = [example["caption"] for example in batch]
tokens = [example["tokens"] for example in batch]
runt_size = batch[0]["runt_size"]
images = images.to(memory_format=torch.contiguous_format).float()
ret = {
"tokens": torch.stack(tuple(tokens)),
"image": images,
"captions": captions,
"runt_size": runt_size,
}
del batch
return ret
def build_torch_dataloader2(dataset, batch_size, max_pixels) -> torch.utils.data.DataLoader:
# shuffle, num_workers, and collate_fn are set inside DataLoaderWithFixedBuffer.__init__
dataloader = DataLoaderWithFixedBuffer(
dataset,
max_pixels=max_pixels,
batch_size=batch_size,
)
return dataloader
return dataloader
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size= batch_size,
@ -148,7 +202,6 @@ def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
)
return dataloader
def collate_fn(batch):
"""
Collates batches

View File

@ -56,7 +56,7 @@ class ImageCaption:
def rating(self) -> float:
return self.__rating
def get_shuffled_caption(self, seed: int) -> str:
def get_shuffled_caption(self, seed: int, keep_tags: int) -> str:
"""
returns the caption as a string with a random selection of the tags in random order
:param seed used to initialize the randomizer
@ -74,7 +74,7 @@ class ImageCaption:
if self.__use_weights:
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
else:
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)
return self.__main_prompt + ", " + tags_caption
return self.__main_prompt
@ -111,8 +111,16 @@ class ImageCaption:
return caption
@staticmethod
def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
random.Random(seed).shuffle(tags)
def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int) -> str:
tags = tags.copy()
keep_tags = max(keep_tags, 0)
if len(tags) > keep_tags:
fixed_tags = tags[:keep_tags]
rest = tags[keep_tags:]
random.Random(seed).shuffle(rest)
tags = fixed_tags + rest
return ", ".join(tags)
class ImageTrainItem:
@ -306,8 +314,10 @@ class ImageTrainItem:
image_aspect = width / height
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
self.is_undersized = (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)
self.is_undersized = (width != target_wh[0] and height != target_wh[1]) and (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)
self.target_wh = target_wh
self.image_size = image.size
except Exception as e:
self.error = e

View File

@ -6,7 +6,7 @@ Start with the [Low VRAM guide](TWEAKING.md) if you are having trouble training
## Resolution
You can train resolutions from 512 to 1024 in 64 pixel increments. General results from the community indicate you can push the base model a bit beyond what it was designed for *with enough training*. This will work out better when you have a lot of training data (hundreds+) and enable slightly higher resolution at inference time without seeing repeats in your generated images. This does cost speed of training and higher VRAM use! Ex. 768 takes a significant amount more VRAM than 512, so you will need to compensate for that by reducing ```batch_size```.
You can train resolutions from 512 to 1024 in 64 pixel increments. General results from the community indicate you can push the base model a bit beyond what it was designed for *with enough training*. This will work out better when you have a lot of training data (hundreds+) and enable slightly higher resolution at inference time without seeing repeats in your generated images. This does cost speed of training and higher VRAM use! Ex. 768 takes a significant amount of additional VRAM compared to 512, so you will need to compensate for that by reducing ```batch_size```.
--resolution 640 ^
@ -14,21 +14,21 @@ For instance, if training from the base 1.5 model, you can try trying at 576, 64
If you are training on a base model that is 768, such as "SD 2.1 768-v", you should also probably use 768 as a base number and adjust from there.
Some results from the community seem to indicate training at a higher resolution on SD1.x models may increase how fast the model learns, and it may be a good idea to slightly reduce your learning rate as you increase resolution. My suspcision is that the higher resolutions increase the gradients as more information is presented to the model per image.
Some results from the community seem to indicate training at a higher resolution on SD1.x models may increase how fast the model learns, and it may be a good idea to slightly reduce your learning rate as you increase resolution. My suspicion is that the higher resolutions increase the gradients as more information is presented to the model per image.
You may need to experiment with LR as you increase resolution. I don't have a perfect rule of thumb here, but I might suggest if you train SD1.5 which is a 512 model at resolution 768 you reduce your LR by about half. ED2 tends to prefer ~2e-6 to ~5e-6 for normal 512 training on SD1.X models around batch 6-8, so if you train SD1.X at 768 consider 1e-6 to 2.5e-6 instead.
You may need to experiment with the LR as you increase resolution. I don't have a perfect rule of thumb here, but I might suggest if you train SD1.5 which is a 512 model at resolution 768 you reduce your LR by about half. ED2 tends to prefer ~2e-6 to ~5e-6 for normal 512 training on SD1.X models around batch 6-8, so if you train SD1.X at 768 consider 1e-6 to 2.5e-6 instead.
## Log and ckpt save folders
If you want to use a nondefault location for saving logs or ckpt files, use these options:
Logdir defaults to the "logs" folder in the trainer directory. If you wan to save all logs (including diffuser copies of ckpts, sample images, and tensbooard events) use this:
Logdir defaults to the "logs" folder in the trainer directory. If you want to save all logs (including diffuser copies of ckpts, sample images, and tensorboard events) use this:
--logdir "/workspace/mylogs"
Remember to use the same folder when you launch tensorboard (```tensorboard --logdir "/workspace/mylogs"```) or it won't find your logs.
By default the CKPT format copies of ckpts that are peroidically saved are saved in the trainer root folder. If you want to save them elsewhere, use this:
By default the CKPT format copies of ckpts that are periodically saved are saved in the trainer root folder. If you want to save them elsewhere, use this:
--save_ckpt_dir "r:\webui\models\stable-diffusion"
@ -125,11 +125,11 @@ Seed can be used to make training either more or less deterministic. The seed v
To use a random seed, use -1:
-- seed -1
--seed -1
Default behavior is to use a fixed seed of 555. The seed you set is fixed for all samples if you set a value other than -1. If you set a seed it is also incremented for shuffling your training data every epoch (i.e. 555, 556, 557, etc). This makes training more deterministic. I suggest a fixed seed when you are trying to A/B test tweaks to your general training setup, or when you want all your test samples to use the same seed.
Fixed seed should be using when performing A/B tests or hyperparameter sweeps. Random seed (-1) may be better if you are stopping and resuming training often so every restart is using random values for all of the various randomness sources used in training such as noising and data shuffling.
Fixed seed should be used when performing A/B tests or hyperparameter sweeps. Random seed (-1) may be better if you are stopping and resuming training often so every restart is using random values for all of the various randomness sources used in training such as noising and data shuffling.
## Shuffle tags
@ -139,6 +139,12 @@ For those training booru tagged models, you can use this arg to randomly (but de
This simply chops the captions into parts based on the commas and shuffles the order.
In case you want to keep static the first N tags, you can also add this parameter (`--shuffle_tags` must also be set):
--keep_tags 4 ^
The above example will keep the first 4 tags static and shuffle the rest.
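To make the behavior concrete, here is a minimal standalone sketch of the keep-and-shuffle logic, mirroring the `__get_shuffled_tags` change in this commit; the tag values and seed below are only illustrative:

```python
import random

def shuffle_with_keep_tags(tags: list[str], seed: int, keep_tags: int = 0) -> str:
    # keep the first `keep_tags` tags in place, shuffle the remainder deterministically
    tags = tags.copy()
    keep_tags = max(keep_tags, 0)
    fixed, rest = tags[:keep_tags], tags[keep_tags:]
    random.Random(seed).shuffle(rest)
    return ", ".join(fixed + rest)

# the first two tags always lead; the rest appear in a seed-determined order
print(shuffle_with_keep_tags(["1girl", "red hair", "smile", "outdoors", "rain"], seed=555, keep_tags=2))
```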
## Zero frequency noise
Based on [Nicholas Guttenberg's blog post](https://www.crosslabs.org//blog/diffusion-with-offset-noise) zero frequency noise offsets the noise added to the image during training/denoising, which can help improve contrast and the ability to render very dark or very bright scenes more accurately, and may help slightly with color saturation.
@ -149,7 +155,7 @@ Based on [Nicholas Guttenberg's blog post](https://www.crosslabs.org//blog/diffu
Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/zero_freq_test_biggs.webp
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time.
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergence over time.
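As a point of reference, here is a minimal sketch of how the offset is mixed into the noise, following the pattern used in `get_model_prediction_and_target` in this commit; the latent shape below is only illustrative:

```python
import torch

def offset_noise_like(latents: torch.Tensor, zero_frequency_noise_ratio: float) -> torch.Tensor:
    # a per-sample, per-channel constant offset is added on top of ordinary gaussian noise
    zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(
        latents.shape[0], latents.shape[1], 1, 1, device=latents.device
    )
    return torch.randn_like(latents) + zero_frequency_noise

latents = torch.randn(4, 4, 64, 64)  # illustrative SD latent batch
noise = offset_noise_like(latents, zero_frequency_noise_ratio=0.02)
print(noise.shape)  # torch.Size([4, 4, 64, 64])
```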
## Zero terminal SNR
@ -205,3 +211,67 @@ Clips the gradient normals to a maximum value. Default is None (no clipping).
Default is no gradient norm clipping. There are also other ways to deal with gradient explosion, such as increasing optimizer epsilon.
## Zero Terminal SNR
**Parameter:** `--enable_zero_terminal_snr`
**Default:** `False`
To enable zero terminal SNR.
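For reference, here is a sketch of the scheduler setup this flag triggers, following the `train.py` changes in this commit; `utils.unet_utils.enforce_zero_terminal_snr` is the repo's own helper and the model path is only illustrative:

```python
from diffusers import DDIMScheduler, DDPMScheduler
from utils.unet_utils import enforce_zero_terminal_snr

model_root_folder = "panopstor/EveryDream"  # illustrative diffusers folder or repo id

# rescale the beta schedule so the final timestep has zero SNR, then rebuild both schedulers from it
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
```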
## Dynamic Configuration Loading
**Parameter:** `--load_settings_every_epoch`
**Default:** `False`
Most of the parameters in the train.json file CANNOT be modified during training. Activate this to have the `train.json` configuration file reloaded at the start of each epoch. The following parameters can be changed and will be applied after the start of a new epoch:
- `--save_every_n_epochs`
- `--save_ckpts_from_n_epochs`
- `--save_full_precision`
- `--save_optimizer`
- `--zero_frequency_noise_ratio`
- `--min_snr_gamma`
- `--clip_skip`
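For illustration, a minimal sketch of the reload step, mirroring the `load_train_json_from_file` helper added in this commit; note that only the parameters listed above actually take effect mid-run:

```python
import json
from argparse import Namespace

def reload_settings(args: Namespace, config_path: str) -> None:
    # re-read the training json and overwrite matching attributes on args;
    # in the trainer this runs at the start of each epoch when --load_settings_every_epoch is set
    try:
        with open(config_path, "rt") as f:
            args.__dict__.update(json.load(f))
    except Exception:
        print(f"Error on loading training config from {config_path}.")

args = Namespace(save_every_n_epochs=1, min_snr_gamma=None, config="train.json")
# reload_settings(args, args.config)  # call once per epoch
```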
## Min-SNR-Gamma Parameter
**Parameter:** `--min_snr_gamma`
**Recommended Values:** 5, 1, 20
**Default:** `None`
To enable min-SNR-Gamma. For an in-depth understanding, consult this [research paper](https://arxiv.org/abs/2303.09556).
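For intuition, here is a sketch of the min-SNR weighted MSE loss as wired into the training step in this commit; the per-timestep SNR values below are random stand-ins for what the trainer's `compute_snr` helper returns:

```python
import torch
import torch.nn.functional as F

def min_snr_weighted_mse(model_pred: torch.Tensor, target: torch.Tensor,
                         snr: torch.Tensor, min_snr_gamma: float) -> torch.Tensor:
    # weight each sample's MSE by min(SNR, gamma) / SNR, per https://arxiv.org/abs/2303.09556
    mse_loss_weights = torch.minimum(snr, torch.full_like(snr, min_snr_gamma)) / snr
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
    return loss.mean()

pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)
snr = torch.rand(4) * 20 + 0.1  # illustrative per-timestep SNR values
print(min_snr_weighted_mse(pred, target, snr, min_snr_gamma=5.0))
```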
## EMA Decay Features
The Exponential Moving Average (EMA) model is copied from the base model at the start and is updated every interval of steps by a small contribution from training.
In this mode, the EMA model will be saved alongside the regular checkpoint from training. The normal training checkpoint can be loaded with `--resume_ckpt`, and the EMA model can be loaded with `--ema_resume_model`.
For more information, consult the [research paper](https://arxiv.org/abs/2101.08482) or continue reading the tuning notes below.
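The per-update rule itself is a simple interpolation; here is a minimal sketch, following the `update_ema` function added to `train.py` in this commit (device juggling omitted):

```python
import torch

@torch.no_grad()
def ema_update(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float) -> None:
    # ema_param <- decay * ema_param + (1 - decay) * param, applied every ema_update_interval steps
    params = dict(model.named_parameters())
    for name, ema_param in ema_model.named_parameters():
        ema_param.copy_(ema_param * decay + params[name].to(ema_param.device) * (1.0 - decay))

# e.g. ema_update(unet, unet_ema, decay=0.9995) once per update interval
```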
**Parameters:**
- `--ema_decay_rate`: Determines the EMA decay rate. It defines how much the EMA model is updated from training at each update. Values should be close to 1 but not exceed it. Activating this parameter triggers the EMA decay feature.
- `--ema_strength_target`: Set the EMA strength target value within the (0,1) range. The `ema_decay_rate` is computed from the relation: decay_rate to the power of (total_steps/decay_interval) equals decay_target. Enabling this parameter overrides `ema_decay_rate` and enables the EMA feature. See [ema_strength_target](#ema_strength_target) for more information.
- `--ema_update_interval`: Set the interval in steps between EMA updates. The update occurs at each optimizer step. If you use grad_accum, the actual update interval will be multiplied by your grad_accum value.
- `--ema_device`: Choose between `cpu` and `cuda` for EMA. Opting for 'cpu' takes around 4 seconds per update and uses approximately 3.2GB RAM, while 'cuda' is much faster but requires a similar amount of VRAM.
- `--ema_sample_nonema_model`: Activate to display samples from the non-ema trained model, mirroring conventional training. They will not be presented by default with EMA decay enabled.
- `--ema_sample_ema_model`: Turn on to exhibit samples from the EMA model. EMA models will be used for sample generation by default with EMA decay enabled, unless disabled.
- `--ema_resume_model`: Indicate the EMA decay checkpoint to continue from, working like `--resume_ckpt` but loading the EMA model. Using `findlast` will only load the EMA version and not the regular training checkpoint.
## Notes on tuning EMA.
The purpose of EMA is to reduce the effect of the data from the tail end of training having an overly powerful effect on the model. Normally training is stopped abruptly, and the final images seen by the trainer may have a stronger effect than images seen earlier in training. *This may have a similar effect to lowering the learning rate near the end of training, but is not mathematically equivalent.* An alternative method to EMA would be to use a cosine learning rate schedule.
Training with EMA turned on has no effect on the non-EMA model if all other settings are identical, though practical considerations (mainly VRAM limits) may cause you to change other settings which can affect the non-EMA model, such as lowering batch size to free enough VRAM for the EMA model if using gpu.
A standard implementation of EMA uses a decay rate of 0.9999, GPU device, and an interval of 1 (every optimizer step). This value can have a strong effect, leading to what appears to be an undertrained EMA model compared to the non-EMA model. A value of 0.999 seems to produce an EMA model nearly identical to the non-EMA model and should be considered a low value. Somewhere in the 0.999-0.9999 range is suggested when using GPU and interval 1.
EMA uses an additional ~3.2GB of RAM (for SD1.x models) to store an extra copy of the model weights in memory. Even for 24GB consumer GPUs this is substantial, but EMA CPU offloading together with a higher `ema_update_interval` can make it more practical. It can be practical on a 24GB GPU if you also enable gradient checkpointing, which is not normally suggested for 24GB GPUs as it is not necessary. Gradient checkpointing saves a bit more than 3.2GB itself. The other option is to use CPU offloading by setting `ema_device: "cpu"`. The EMA copy of the model will be stored in main system memory instead of on the GPU, but at the cost of slower sampling and updating. CPU offloading is a requirement for GPUs with 16GB or less VRAM, and even 24GB GPU users may wish to consider it. If you are using a 40GB+ GPU you should use the GPU device.
When using a higher interval to make cpu offloading practical and reasonably fast, the decay rate should be lowered. For instance, with an interval of 50, you may wish to lower the decay rate to 0.99 or possibly lower. This is because the EMA model is updated less frequently and the decay rate is effectively higher than the set value under otherwise "normal" EMA training regime. The higher interval also reduces accuracy of the EMA model compared to the reference implementation which would normally update EMA every optimizer step.
I would suggest you pick an interval and stick with it, and then tune your decay_rate by generating samples from both the EMA and non-EMA models (using the sample options above, or after training with your favorite inference app) and comparing the results.
It is expected the EMA model will look "behind" on training, but should still be recognizable as the same subject matter. If it is not, you may wish to try a lower decay rate. If it is too close to the non-EMA model, you may wish to try a higher decay rate.
Using the GPU for EMA incurs only a small speed penalty of around 5-10% with all else being equal, though if you change other parameters, such as lowering batch size or enabling the gradient checkpointing flag to free VRAM for EMA, those options may incur a slightly higher speed penalty.
Generally, I recommend picking a device and an appropriate interval given your device choice first and sticking with those values, then tweaking the `ema_decay_rate` up or down according to how you want the EMA model to look vs. your non-EMA model. From there, if your EMA model seems to "lag behind" the non-EMA model by "too much" (subjectively judged), you can decrease the decay rate. If it is identical or nearly identical, use a slightly higher value.
## ema_strength_target
This arg is a non-standard way of calculating the actual decay rate used. It attempts to calculate a value for the decay rate based on your `ema_update_interval` and the total length of training, compensating for both. Values of 0.01-0.15 should work, with higher values leading to an EMA model that deviates more from the non-EMA model, similar to how the decay rate works. It attempts to be more of a "strength" value for EMA, or "how much" (as a factor, i.e. 0.10 = 10% "strength") of the EMA model is kept over the totality of training.
While the calculation makes sense in how it compensates for interval and total training length, it is not a standard way of calculating the decay rate and there will not be information online about how to use it. I recommend not using this feature; instead, pick a device and an appropriate interval for that device first, then tune your decay rate by hand until you find "good" values, and then don't mess with them. You can still try this feature out if you want.
--ema_strength_target 0.10 ^
If you use `ema_strength_target` the actual calculated `ema_decay_rate` used will be printed in your logs, and you should pay attention to this value and use it to inform your future decisions on EMA tuning.
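To see what value actually gets used, here is a sketch of the calculation, mirroring the `ema_decay_rate` computation in `train.py` from this commit; the step counts are only illustrative:

```python
def decay_rate_from_strength_target(strength_target: float, total_steps: int, update_interval: int) -> float:
    # solve decay_rate ** (total_steps / update_interval) == strength_target for decay_rate
    total_updates = total_steps / update_interval
    return strength_target ** (1.0 / total_updates)

# illustrative numbers: 20,000 total steps with an EMA update every 50 optimizer steps
print(decay_rate_from_strength_target(0.10, total_steps=20_000, update_interval=50))  # ~0.9943
```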

View File

@ -4,7 +4,7 @@
`python caption_fl.py --data_root input --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model "openflamingo/OpenFlamingo-9B-vitl-mpt7b"`
This script uses two example image/caption pairs located in the `/example` folder to prime the system to caption, then captions the images in the input folder. It will save a `.txt` file of the same base filename with the captoin in the same folder.
This script uses two example image/caption pairs located in the `/example` folder to prime the system to caption, then captions the images in the input folder. It will save a `.txt` file of the same base filename with the caption in the same folder.
This script currently requires an AMPERE or newer GPU due to using bfloat16.
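If you are unsure whether your GPU qualifies, a quick check like the following can be run beforehand; this uses standard PyTorch capability queries and is not part of the script itself:

```python
import torch

# bfloat16 autocast needs Ampere (compute capability 8.0) or newer
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: {major}.{minor}, bf16 supported: {torch.cuda.is_bf16_supported()}")
else:
    print("no CUDA device found")
```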

View File

@ -29,12 +29,13 @@ For each of the `unet` and `text_encoder` sections, you can set the following pr
Standard full precision AdamW optimizer exposed by PyTorch. Not recommended. Slower and uses more memory than adamw8bit. Widely documented on the web.
* adamw8bit
* lion8bit
Tim Dettmers / bitsandbytes AdamW 8bit optimizer. This is the default and recommended setting. Widely documented on the web.
Tim Dettmers / bitsandbytes AdamW and Lion 8bit optimizers. adamw8bit is the default and recommended setting as it is well understood, and lion8bit is very VRAM efficient. Widely documented on the web.
* lion
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click links to read more. `Epsilon` is not used by lion.
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click links to read more. `Epsilon` is not used by lion. You should prefer lion8bit over this optimizer as it is more memory efficient.
Recommended settings for lion based on the paper are as follows:
@ -61,7 +62,13 @@ Available optimizer values for Dadaptation are:
* dadapt_lion, dadapt_adam, dadapt_sgd
These are fairly experimental but tested as working. Gradient checkpointing may be required even on 24GB GPUs. Performance is slower than the compiled and optimized AdamW8bit optimizer unless you increae gradient accumulation as it seems the accumulation steps process slowly with the current implementation of D-Adaption
These are fairly experimental but tested as working. Gradient checkpointing may be required even on 24GB GPUs. Performance is slower than the compiled and optimized AdamW8bit optimizer unless you increase gradient accumulation, as it seems the accumulation steps process slowly with the current implementation of D-Adaptation.
#### Prodigy
Another adaptive optimizer. It is not very VRAM efficient. [Github](https://github.com/konstmish/prodigy), [Paper](https://arxiv.org/pdf/2306.06101.pdf)
* prodigy
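As an illustration, here is a sketch of constructing Prodigy with the settings this commit wires into `optimizers.py`; it assumes the `prodigyopt` package added to requirements, and the model and hyperparameter values are stand-ins:

```python
import torch
from prodigyopt import Prodigy

model = torch.nn.Linear(8, 8)    # stand-in for the unet / text encoder parameters
optimizer = Prodigy(
    model.parameters(),
    lr=1.0,                      # Prodigy adapts the step size; lr acts as a multiplier
    weight_decay=0.01,
    use_bias_correction=True,    # suggested by the prodigy github, as set in this commit
    safeguard_warmup=True,       # per the prodigy documentation, as set in this commit
    d0=1e-6,                     # initial D estimate (prodigyopt default)
)
```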
## Optimizer parameters

View File

@ -65,3 +65,4 @@ The effect of the limit is that the caption will always be truncated when the ma
exceeded. This process does not consider if the cutoff is in the middle of a tag or even in the middle of a
word if it is translated into several tokens.
To mitigate this token limitation (when not using weighted shuffling), the `--keep_tags n` parameter can be employed. This ensures that the first n tags following the initial chunk remain static, while the remaining tags are shuffled.

View File

@ -1,5 +1,6 @@
aiohttp==3.8.4
bitsandbytes==0.38.1
bitsandbytes==0.41.1
scipy
colorama==0.4.6
compel~=1.1.3
ftfy==6.1.1
@ -14,4 +15,5 @@ pynvml==11.5.0
speedtest-cli
tensorboard==2.12.0
wandb
safetensors
safetensors
prodigyopt

View File

@ -276,6 +276,7 @@ class EveryDreamOptimizer():
decouple = True # seems bad to turn off, dadapt_adam only
momentum = 0.0 # dadapt_sgd
no_prox = False # ????, dadapt_adan
use_bias_correction = True # suggested by the prodigy github
growth_rate=float("inf") # dadapt various, no idea what a sane default is
if local_optimizer_config is not None:
@ -307,6 +308,30 @@ class EveryDreamOptimizer():
betas=(betas[0], betas[1]),
weight_decay=weight_decay,
)
elif optimizer_name == "lion8bit":
from bitsandbytes.optim import Lion8bit
opt_class = Lion8bit
optimizer = opt_class(
itertools.chain(parameters),
lr=curr_lr,
betas=(betas[0], betas[1]),
weight_decay=weight_decay,
percentile_clipping=100,
min_8bit_size=4096,
)
elif optimizer_name == "prodigy":
from prodigyopt import Prodigy
opt_class = Prodigy
safeguard_warmup = True # per recommendation from prodigy documentation
optimizer = opt_class(
itertools.chain(parameters),
lr=curr_lr,
weight_decay=weight_decay,
use_bias_correction=use_bias_correction,
growth_rate=growth_rate,
d0=d0,
safeguard_warmup=safeguard_warmup
)
elif optimizer_name == "adamw":
opt_class = torch.optim.AdamW
if "dowg" in optimizer_name:
@ -317,7 +342,7 @@ class EveryDreamOptimizer():
elif optimizer_name == "scalar_dowg":
opt_class = dowg.ScalarDoWG
else:
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are coordinate_dowg and scalar_dowg")
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are 'coordinate_dowg' and 'scalar_dowg'")
elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
import dadaptation

56
plugins/interruptible.py Normal file
View File

@ -0,0 +1,56 @@
import math
import os
import shutil
from plugins.plugins import BasePlugin
from train import save_model
EVERY_N_EPOCHS = 1 # how often to save. integers >= 1 save at the end of every nth epoch. floats < 1 subdivide the epoch evenly (eg 0.33 = 3 subdivisions)
class InterruptiblePlugin(BasePlugin):
def __init__(self):
print("Interruptible plugin instantiated")
self.previous_save_path = None
self.every_n_epochs = EVERY_N_EPOCHS
def on_epoch_start(self, **kwargs):
epoch = kwargs['epoch']
epoch_length = kwargs['epoch_length']
self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
def on_step_end(self, **kwargs):
local_step = kwargs['local_step']
if local_step in self.steps_to_save_this_epoch:
global_step = kwargs['global_step']
epoch = kwargs['epoch']
project_name = kwargs['project_name']
log_folder = kwargs['log_folder']
ckpt_name = f"rolling-{project_name}-ep{epoch:02}-gs{global_step:05}"
save_path = os.path.join(log_folder, "ckpts", ckpt_name)
print(f"{type(self)} saving model to {save_path}")
save_model(save_path, global_step=global_step, ed_state=kwargs['ed_state'], save_ckpt_dir=None, yaml_name=None, save_ckpt=False, save_full_precision=True, save_optimizer_flag=True)
self._remove_previous()
self.previous_save_path = save_path
def on_training_end(self, **kwargs):
self._remove_previous()
def _remove_previous(self):
if self.previous_save_path is not None:
shutil.rmtree(self.previous_save_path, ignore_errors=True)
self.previous_save_path = None
def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
if self.every_n_epochs >= 1:
if ((epoch+1) % self.every_n_epochs) == 0:
# last step only
return [epoch_length_steps-1]
else:
return []
else:
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
# saving happens after the training step:
# if an epoch has eg 100 steps and num_divisions is 2, then saves should occur after steps 49 and 99
save_every_n_steps = epoch_length_steps / num_divisions
return [math.ceil((i+1)*save_every_n_steps) - 1 for i in range(num_divisions)]

View File

@ -44,7 +44,7 @@ class Timer:
def __exit__(self, type, value, traceback):
elapsed_time = time.time() - self.start
if elapsed_time > self.warn_seconds:
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.limit} seconds')
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')
class PluginRunner:

View File

@ -3,7 +3,7 @@ torchvision==0.15.2
transformers==4.29.2
diffusers[torch]==0.18.0
pynvml==11.4.1
bitsandbytes==0.38.1
bitsandbytes==0.41.1
ftfy==6.1.1
aiohttp==3.8.4
tensorboard>=2.11.0
@ -21,4 +21,4 @@ numpy==1.23.5
wandb
colorama
safetensors
open-flamingo==2.0.0
open-flamingo==2.0.0

View File

@ -4,6 +4,7 @@ import pathlib
import PIL.Image as Image
from data.image_train_item import ImageCaption, ImageTrainItem
import data.aspects as aspects
DATA_PATH = pathlib.Path('./test/data')
@ -32,4 +33,70 @@ class TestImageCaption(unittest.TestCase):
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
self.assertEqual(caption.get_caption(), "hello world")
self.assertEqual(caption.get_caption(), "hello world")
class TestImageTrainItemConstructor(unittest.TestCase):
def tearDown(self) -> None:
for file in DATA_PATH.glob("img_*"):
file.unlink()
return super().tearDown()
@staticmethod
def image_with_size(width, height):
filename = DATA_PATH / "img_{}x{}.jpg".format(width, height)
Image.new("RGB", (width, height)).save(filename)
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
return ImageTrainItem(None, caption, aspects.ASPECTS_512, filename, 0.0, 1.0, False, False, 0)
def test_target_size_computation(self):
# Square images
image = self.image_with_size(30, 30)
self.assertEqual(image.target_wh, [512,512])
self.assertTrue(image.is_undersized)
self.assertEqual(image.image_size, (30,30))
image = self.image_with_size(512, 512)
self.assertEqual(image.target_wh, [512,512])
self.assertFalse(image.is_undersized)
self.assertEqual(image.image_size, (512,512))
image = self.image_with_size(580, 580)
self.assertEqual(image.target_wh, [512,512])
self.assertFalse(image.is_undersized)
self.assertEqual(image.image_size, (580,580))
# Landscape images
image = self.image_with_size(64, 38)
self.assertEqual(image.target_wh, [640,384])
self.assertTrue(image.is_undersized)
self.assertEqual(image.image_size, (64,38))
image = self.image_with_size(640, 384)
self.assertEqual(image.target_wh, [640,384])
self.assertFalse(image.is_undersized)
self.assertEqual(image.image_size, (640,384))
image = self.image_with_size(704, 422)
self.assertEqual(image.target_wh, [640,384])
self.assertFalse(image.is_undersized)
self.assertEqual(image.image_size, (704,422))
# Portrait images
image = self.image_with_size(38, 64)
self.assertEqual(image.target_wh, [384,640])
self.assertTrue(image.is_undersized)
self.assertEqual(image.image_size, (38,64))
image = self.image_with_size(384, 640)
self.assertEqual(image.target_wh, [384,640])
self.assertFalse(image.is_undersized)
self.assertEqual(image.image_size, (384,640))
image = self.image_with_size(422, 704)
self.assertEqual(image.target_wh, [384,640])
self.assertFalse(image.is_undersized)
self.assertEqual(image.image_size, (422,704))

View File

@ -4,7 +4,7 @@
"clip_grad_norm": null,
"clip_skip": 0,
"cond_dropout": 0.04,
"data_root": "X:\\my_project_data\\project_abc",
"data_root": "/mnt/q/training_samples/ff7r/man",
"disable_amp": false,
"disable_textenc_training": false,
"disable_xformers": false,
@ -19,12 +19,12 @@
"lr_decay_steps": 0,
"lr_scheduler": "constant",
"lr_warmup_steps": null,
"max_epochs": 30,
"max_epochs": 1,
"notebook": false,
"optimizer_config": "optimizer.json",
"project_name": "project_abc",
"resolution": 512,
"resume_ckpt": "sd_v1-5_vae",
"resume_ckpt": "panopstor/EveryDream",
"run_name": null,
"sample_prompts": "sample_prompts.txt",
"sample_steps": 300,
@ -40,5 +40,15 @@
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02
"zero_frequency_noise_ratio": 0.02,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_strength_target": null,
"ema_update_interval": null,
"ema_device": null,
"ema_sample_nonema_model": false,
"ema_sample_ema_model": false,
"ema_resume_model" : null
}

579
train.py
View File

@ -27,6 +27,7 @@ import gc
import random
import traceback
import shutil
from typing import Optional
import torch.nn.functional as F
from torch.cuda.amp import autocast
@ -61,6 +62,7 @@ from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
from utils.check_git import check_git
from optimizer.optimizers import EveryDreamOptimizer
from copy import deepcopy
if torch.cuda.is_available():
from utils.gpu import GPU
@ -101,6 +103,108 @@ def convert_to_hf(ckpt_path):
is_sd1attn, yaml = get_attn_yaml(ckpt_path)
return ckpt_path, is_sd1attn, yaml
class EveryDreamTrainingState:
def __init__(self,
optimizer: EveryDreamOptimizer,
train_batch: EveryDreamBatch,
unet: UNet2DConditionModel,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
scheduler,
vae: AutoencoderKL,
unet_ema: Optional[UNet2DConditionModel],
text_encoder_ema: Optional[CLIPTextModel]
):
self.optimizer = optimizer
self.train_batch = train_batch
self.unet = unet
self.text_encoder = text_encoder
self.tokenizer = tokenizer
self.scheduler = scheduler
self.vae = vae
self.unet_ema = unet_ema
self.text_encoder_ema = text_encoder_ema
@torch.no_grad()
def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
"""
Save the model to disk
"""
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
nonlocal save_ckpt_dir
nonlocal save_full_precision
nonlocal yaml_name
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
if args.ema_decay_rate != None:
pipeline_ema = StableDiffusionPipeline(
vae=ed_state.vae,
text_encoder=ed_state.text_encoder_ema,
tokenizer=ed_state.tokenizer,
unet=ed_state.unet_ema,
scheduler=ed_state.scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
diffusers_model_path = save_path + "_ema"
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
pipeline_ema.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
pipeline = StableDiffusionPipeline(
vae=ed_state.vae,
text_encoder=ed_state.text_encoder,
tokenizer=ed_state.tokenizer,
unet=ed_state.unet,
scheduler=ed_state.scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
diffusers_model_path = save_path
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
pipeline.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
if save_optimizer_flag:
logging.info(f" Saving optimizer state to {save_path}")
ed_state.optimizer.save(save_path)
def setup_local_logger(args):
"""
configures logger with file and console logging, logs args, and returns the datestamp
@ -186,7 +290,7 @@ def set_args_12gb(args):
logging.info(" - Overiding resolution to max 512")
args.resolution = 512
def find_last_checkpoint(logdir):
def find_last_checkpoint(logdir, is_ema=False):
"""
Finds the last checkpoint in the logdir, recursively
"""
@ -196,6 +300,12 @@ def find_last_checkpoint(logdir):
for root, dirs, files in os.walk(logdir):
for file in files:
if os.path.basename(file) == "model_index.json":
if is_ema and (not root.endswith("_ema")):
continue
elif (not is_ema) and root.endswith("_ema"):
continue
curr_date = os.path.getmtime(os.path.join(root,file))
if last_date is None or curr_date > last_date:
@ -228,12 +338,20 @@ def setup_args(args):
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if (args.ema_resume_model != None) and (args.ema_resume_model == "findlast"):
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
args.ema_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
if args.lowvram:
set_args_12gb(args)
if not args.shuffle_tags:
args.shuffle_tags = False
if not args.keep_tags:
args.keep_tags = 0
args.clip_skip = max(min(4, args.clip_skip), 0)
if args.useadam8bit:
@ -356,6 +474,74 @@ def log_args(log_writer, args):
arglog += f"{arg}={value}, "
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device, ema_device):
with torch.no_grad():
original_model_on_proper_device = model
need_to_delete_original = False
if ema_device != default_device:
original_model_on_other_device = deepcopy(model)
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
del original_model_on_other_device
need_to_delete_original = True
params = dict(original_model_on_proper_device.named_parameters())
ema_params = dict(ema_model.named_parameters())
for name in ema_params:
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
if need_to_delete_original:
del(original_model_on_proper_device)
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
minimal_value = 1e-9
alphas_cumprod = noise_scheduler.alphas_cumprod
# Use .any() to check if any elements in the tensor are zero
if (alphas_cumprod[:-1] == 0).any():
logging.warning(
f"Alphas cumprod has zero elements! Resetting to {minimal_value}.."
)
alphas_cumprod[alphas_cumprod[:-1] == 0] = minimal_value
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
timesteps
].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
device=timesteps.device
)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR, first without epsilon
snr = (alpha / sigma) ** 2
# Check if the first element in SNR tensor is zero
if torch.any(snr == 0):
snr[snr == 0] = minimal_value
return snr
def load_train_json_from_file(args, report_load = False):
try:
if report_load:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
args.__dict__.update(read_json)
except Exception as config_read:
print(f"Error on loading training config from {args.config}.")
def main(args):
"""
@ -384,57 +570,28 @@ def main(args):
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
"""
Save the model to disk
"""
global global_step
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
logging.info(f" * Saving diffusers model to {save_path}")
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be none of no safety checker
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
def release_memory(model_to_delete, original_device):
del model_to_delete
gc.collect()
if save_ckpt:
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
if 'cuda' in original_device.type:
torch.cuda.empty_cache()
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_strength_target != None)
ema_model_loaded_from_file = False
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if save_optimizer_flag:
logging.info(f" Saving optimizer state to {save_path}")
ed_optimizer.save(save_path)
if use_ema_dacay_training:
ema_device = torch.device(args.ema_device)
optimizer_state_path = None
try:
# check for a local file
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
@ -443,10 +600,6 @@ def main(args):
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
if not os.path.exists(optimizer_state_path):
optimizer_state_path = None
else:
# try to download from HF using resume_ckpt as a repo id
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
@ -457,16 +610,52 @@ def main(args):
vae = pipe.vae
unet = pipe.unet
del pipe
if args.zero_frequency_noise_ratio == -1.0:
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
if use_ema_dacay_training and args.ema_resume_model:
print(f"Loading EMA model: {args.ema_resume_model}")
ema_model_loaded_from_file=True
hf_cache_path = get_hf_ckpt_cache_path(args.ema_resume_model)
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_resume_model):
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.ema_resume_model)
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
else:
# try to download from HF using ema_resume_model as a repo id
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_resume_model)
if ema_downloaded is None:
raise ValueError(
f"No local file/folder for ema_resume_model {args.ema_resume_model}, and no matching huggingface.co repo could be downloaded")
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
text_encoder_ema = ema_pipe.text_encoder
unet_ema = ema_pipe.unet
del ema_pipe
# Make sure EMA model is on proper device, and memory released if moved
unet_ema_current_device = next(unet_ema.parameters()).device
if ema_device != unet_ema_current_device:
unet_ema_on_wrong_device = unet_ema
unet_ema = unet_ema.to(ema_device)
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
# Make sure EMA model is on proper device, and memory released if moved
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
if ema_device != text_encoder_ema_current_device:
text_encoder_ema_on_wrong_device = text_encoder_ema
text_encoder_ema = text_encoder_ema.to(ema_device)
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
if args.enable_zero_terminal_snr:
# Use zero terminal SNR
from utils.unet_utils import enforce_zero_terminal_snr
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
@ -503,6 +692,32 @@ def main(args):
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
if use_ema_dacay_training:
if not ema_model_loaded_from_file:
logging.info(f"EMA decay enabled, creating EMA model.")
with torch.no_grad():
if args.ema_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
else:
# Make sure correct types are used for models
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
else:
unet_ema = None
text_encoder_ema = None
try:
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
@ -566,6 +781,7 @@ def main(args):
tokenizer=tokenizer,
seed = seed,
shuffle_tags=args.shuffle_tags,
keep_tags=args.keep_tags,
rated_dataset=args.rated_dataset,
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
)
@ -574,6 +790,20 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
if use_ema_dacay_training:
args.ema_update_interval = args.ema_update_interval * args.grad_accum
if args.ema_strength_target != None:
total_number_of_steps: float = epoch_len * args.max_epochs
total_number_of_ema_update: float = total_number_of_steps / args.ema_update_interval
args.ema_decay_rate = args.ema_strength_target ** (1 / total_number_of_ema_update)
logging.info(f"ema_strength_target is {args.ema_strength_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
logging.info(
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")
ed_optimizer = EveryDreamOptimizer(args,
optimizer_config,
text_encoder,
@ -588,7 +818,8 @@ def main(args):
batch_size=max(1,args.batch_size//2),
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers,
use_penultimate_clip_layer=(args.clip_skip >= 2)
use_penultimate_clip_layer=(args.clip_skip >= 2),
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
)
"""
@ -620,7 +851,9 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
save_model(interrupted_checkpoint_path, global_step=global_step, ed_state=make_current_ed_state(),
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
exit(_SIGTERM_EXIT_CODE)
else:
# non-main threads (i.e. dataloader workers) should exit cleanly
@ -668,7 +901,7 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@ -676,12 +909,13 @@ def main(args):
del pixel_values
latents = latents[0].sample() * 0.18215
if zero_frequency_noise_ratio > 0.0:
if zero_frequency_noise_ratio != None:
if zero_frequency_noise_ratio < 0:
zero_frequency_noise_ratio = 0
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
noise = torch.randn_like(latents) + zero_frequency_noise
else:
noise = torch.randn_like(latents)
bsz = latents.shape[0]
@ -712,9 +946,35 @@ def main(args):
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
return model_pred, target
if return_loss:
if args.min_snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(timesteps, noise_scheduler)
mse_loss_weights = (
torch.stack(
[snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
mse_loss_weights[snr == 0] = 1.0
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
return model_pred, target, loss
else:
return model_pred, target
def generate_samples(global_step: int, batch):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
@ -723,19 +983,78 @@ def main(args):
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=reference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
models_info = []
if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
models_info.append({"is_ema": False, "swap_required": False})
if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
models_info.append({"is_ema": True, "swap_required": ema_device != device})
for model_info in models_info:
extra_info: str = ""
if model_info["is_ema"]:
current_unet, current_text_encoder = unet_ema, text_encoder_ema
extra_info = "_ema"
else:
current_unet, current_text_encoder = unet, text_encoder
torch.cuda.empty_cache()
if model_info["swap_required"]:
with torch.no_grad():
unet_unloaded = unet.to(ema_device)
del unet
text_encoder_unloaded = text_encoder.to(ema_device)
del text_encoder
current_unet = unet_ema.to(device)
del unet_ema
current_text_encoder = text_encoder_ema.to(device)
del text_encoder_ema
gc.collect()
torch.cuda.empty_cache()
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
text_encoder=current_text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
# Cleanup
del inference_pipe
if model_info["swap_required"]:
with torch.no_grad():
unet = unet_unloaded.to(device)
del unet_unloaded
text_encoder = text_encoder_unloaded.to(device)
del text_encoder_unloaded
unet_ema = current_unet.to(ema_device)
del current_unet
text_encoder_ema = current_text_encoder.to(ema_device)
del current_text_encoder
gc.collect()
torch.cuda.empty_cache()
def make_save_path(epoch, global_step, prepend=""):
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
basename = f"{prepend}{args.project_name}"
if epoch is not None:
basename += f"-ep{epoch:02}"
if global_step is not None:
basename += f"-gs{global_step:05}"
return os.path.join(log_folder, "ckpts", basename)
# Pre-train validation to establish a starting point on the loss graph
if validator:
@ -753,25 +1072,47 @@ def main(args):
else:
logging.info("No plugins specified")
plugins = []
from plugins.plugins import PluginRunner
plugin_runner = PluginRunner(plugins=plugins)
def make_current_ed_state() -> EveryDreamTrainingState:
return EveryDreamTrainingState(optimizer=ed_optimizer,
train_batch=train_batch,
unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=noise_scheduler,
vae=vae,
unet_ema=unet_ema,
text_encoder_ema=text_encoder_ema)
epoch = None
try:
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
for epoch in range(args.max_epochs):
plugin_runner.run_on_epoch_start(epoch=epoch,
global_step=global_step,
project_name=args.project_name,
log_folder=log_folder,
data_root=args.data_root)
if args.load_settings_every_epoch:
load_train_json_from_file(args)
epoch_len = math.ceil(len(train_batch) / args.batch_size)
plugin_runner.run_on_epoch_start(
epoch=epoch,
global_step=global_step,
epoch_length=epoch_len,
project_name=args.project_name,
log_folder=log_folder,
data_root=args.data_root
)
loss_epoch = []
epoch_start_time = time.time()
images_per_sec_log_step = []
epoch_len = math.ceil(len(train_batch) / args.batch_size)
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
@ -781,16 +1122,18 @@ def main(args):
)
for step, batch in enumerate(train_dataloader):
step_start_time = time.time()
plugin_runner.run_on_step_start(epoch=epoch,
local_step=step,
global_step=global_step,
project_name=args.project_name,
log_folder=log_folder,
batch=batch)
batch=batch,
ed_state=make_current_ed_state())
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
del target, model_pred
@ -800,6 +1143,21 @@ def main(args):
ed_optimizer.step(loss, step, global_step)
if args.ema_decay_rate != None:
if ((global_step + 1) % args.ema_update_interval) == 0:
# debug_start_time = time.time() # Measure time
if args.disable_unet_training != True:
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
if args.disable_textenc_training != True:
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
# debug_end_time = time.time() # Measure time
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step}, {"gs": global_step})
@ -816,7 +1174,7 @@ def main(args):
lr_unet = ed_optimizer.get_unet_lr()
lr_textenc = ed_optimizer.get_textenc_lr()
loss_log_step = []
log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_step, global_step=global_step)
@ -840,23 +1198,29 @@ def main(args):
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
needs_save = False
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
needs_save = True
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
needs_save = True
if needs_save:
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
save_ckpt_dir=args.save_ckpt_dir, yaml_name=None,
save_full_precision=args.save_full_precision,
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
plugin_runner.run_on_step_end(epoch=epoch,
global_step=global_step,
local_step=step,
project_name=args.project_name,
log_folder=log_folder,
data_root=args.data_root,
batch=batch)
batch=batch,
ed_state=make_current_ed_state())
del batch
global_step += 1
@ -873,8 +1237,9 @@ def main(args):
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
loss_epoch = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
if len(loss_epoch) > 0:
loss_epoch = sum(loss_epoch) / len(loss_epoch)
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
plugin_runner.run_on_epoch_end(epoch=epoch,
global_step=global_step,
@ -882,13 +1247,18 @@ def main(args):
log_folder=log_folder,
data_root=args.data_root)
gc.collect()
gc.collect()
# end of epoch
# end of training
epoch = args.max_epochs
plugin_runner.run_on_training_end()
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -898,7 +1268,9 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = make_save_path(epoch, global_step, prepend="errored-")
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
raise ex
@ -914,14 +1286,7 @@ if __name__ == "__main__":
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, argv = argparser.parse_known_args()
if args.config is not None:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
args.__dict__.update(json.load(f))
if len(argv) > 0:
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
else:
print("No config file specified, using command line args")
load_train_json_from_file(args, report_load=True)
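# Rough sketch of the behaviour expected from load_train_json_from_file (assumed; the
# helper itself is defined elsewhere in the repo): read the JSON config named by
# args.config, update args.__dict__ with it, and note that remaining CLI arguments
# will override the loaded values, e.g.
#   def load_train_json_from_file(args, report_load=False):
#       if args.config is None:
#           return
#       if report_load:
#           print(f"Loading training config from {args.config}.")
#       with open(args.config, "rt") as f:
#           args.__dict__.update(json.load(f))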
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
@ -936,7 +1301,7 @@ if __name__ == "__main__":
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
@ -964,15 +1329,27 @@ if __name__ == "__main__":
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameter is the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_rate) from training, and the ema_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_strength_target", type=float, default=None, help="EMA decay target value in range (0,1). emarate will be calculated from equation: 'ema_decay_rate=ema_strength_target^(total_steps/ema_update_interval)'. Using this parameter will enable the ema feature and overide ema_decay_rate.")
argparser.add_argument("--ema_update_interval", type=int, default=500, help="How many steps between optimizer steps that EMA decay updates. EMA model will be update on every step modulo grad_accum times ema_update_interval.")
argparser.add_argument("--ema_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' the system RAM will be used.")
argparser.add_argument("--ema_sample_nonema_model", action="store_true", default=False, help="Will show samples from non-EMA trained model, just like regular training. Can be used with: --ema_sample_ema_model")
argparser.add_argument("--ema_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. May be slower when using ema cpu offloading. Can be used with: --ema_sample_nonema_model")
argparser.add_argument("--ema_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
# load CLI args to overwrite existing config args
args = argparser.parse_args(args=argv, namespace=args)
main(args)
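# Worked example (hypothetical helper, not the trainer's code) of how the EMA options above
# relate: --ema_strength_target is the fraction of the original EMA weights that should
# remain after the whole run, so the per-update decay is its Nth root, where N is the
# total number of EMA updates.
def derive_ema_decay_rate(ema_strength_target: float, total_steps: int, ema_update_interval: int) -> float:
    num_ema_updates = total_steps / ema_update_interval  # how many EMA blends happen in total
    return ema_strength_target ** (1.0 / num_ema_updates)
# Example: a 10,000-step run with --ema_update_interval 500 and --ema_strength_target 0.01
# implies a per-update decay rate of 0.01 ** (1/20) ~= 0.794.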

View File

@ -39,5 +39,15 @@
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02
"zero_frequency_noise_ratio": 0.02,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_strength_target": null,
"ema_update_interval": null,
"ema_device": null,
"ema_sample_nonema_model": false,
"ema_sample_ema_model": false,
"ema_resume_model" : null
}

View File

@ -23,6 +23,7 @@
import os.path as osp
import re
from safetensors import safe_open
import torch
@ -288,17 +289,38 @@ def convert(model_path: str, checkpoint_path: str, half: bool):
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
# Convert the UNet model
unet_state_dict = torch.load(unet_path, map_location="cpu")
if osp.exists(unet_path):
unet_state_dict = torch.load(unet_path, map_location="cpu")
else:
unet_state_dict = {}
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
with safe_open(unet_path, framework="pt", device="cpu") as f:
for key in f.keys():
unet_state_dict[key] = f.get_tensor(key)
unet_state_dict = convert_unet_state_dict(unet_state_dict)
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
# Convert the VAE model
vae_state_dict = torch.load(vae_path, map_location="cpu")
if osp.exists(vae_path):
vae_state_dict = torch.load(vae_path, map_location="cpu")
else:
vae_state_dict = {}
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
with safe_open(vae_path, framework="pt", device="cpu") as f:
for key in f.keys():
vae_state_dict[key] = f.get_tensor(key)
vae_state_dict = convert_vae_state_dict(vae_state_dict)
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
# Convert the text encoder model
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
if osp.exists(text_enc_path):
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
else:
text_enc_dict = {}
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
with safe_open(text_enc_path, framework="pt", device="cpu") as f:
for key in f.keys():
text_enc_dict[key] = f.get_tensor(key)
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
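# The three branches above repeat the same ".bin first, then safetensors" fallback; a
# hedged sketch of a shared helper (hypothetical, not part of this script) could read:
def load_bin_or_safetensors(bin_path, safetensors_path):
    # Prefer the legacy PyTorch .bin weights when present, otherwise read the safetensors file.
    if osp.exists(bin_path):
        return torch.load(bin_path, map_location="cpu")
    state_dict = {}
    with safe_open(safetensors_path, framework="pt", device="cpu") as f:
        for key in f.keys():
            state_dict[key] = f.get_tensor(key)
    return state_dict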

View File

@ -90,7 +90,8 @@ class SampleGenerator:
default_seed: int,
default_sample_steps: int,
use_xformers: bool,
use_penultimate_clip_layer: bool):
use_penultimate_clip_layer: bool,
guidance_rescale: float = 0):
self.log_folder = log_folder
self.log_writer = log_writer
self.batch_size = batch_size
@ -99,6 +100,7 @@ class SampleGenerator:
self.show_progress_bars = False
self.generate_pretrain_samples = False
self.use_penultimate_clip_layer = use_penultimate_clip_layer
self.guidance_rescale = guidance_rescale
self.default_resolution = default_resolution
self.default_seed = default_seed
@ -182,7 +184,7 @@ class SampleGenerator:
self.sample_requests = self._make_random_caption_sample_requests()
@torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
"""
generates samples at different cfg scales and saves them to disk
"""
@ -231,6 +233,7 @@ class SampleGenerator:
generator=generators,
width=size[0],
height=size[1],
guidance_rescale=self.guidance_rescale
).images
for image in images:
@ -269,15 +272,15 @@ class SampleGenerator:
prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
else:
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1
del result
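# The guidance_rescale value added above is forwarded to the diffusers pipeline; as a
# hedged sketch (following the rescaled classifier-free guidance idea from the zero
# terminal SNR paper, not this file's code), the pipeline pulls the CFG prediction back
# toward the standard deviation of the text-conditioned prediction:
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)  # match the text-only std
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg  # 0 disables rescaling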

View File

@ -7,7 +7,7 @@ pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "http
pip install -U transformers==4.29.2
pip install -U diffusers[torch]==0.18.0
pip install pynvml==11.4.1
pip install -U https://github.com/victorchall/everydream-whls/raw/main/bitsandbytes-0.38.1-py2.py3-none-any.whl
pip install -U https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl
pip install ftfy==6.1.1
pip install aiohttp==3.8.4
pip install tensorboard>=2.11.0
@ -23,6 +23,7 @@ pip install compel~=1.1.3
pip install dadaptation
pip install safetensors
pip install open-flamingo==2.0.0
pip install prodigyopt
python utils/get_yamls.py
GOTO :eof