Merge branch 'victorchall:main' into feat_add_sde_samplers
This commit is contained in:
commit
bc1058a0d5
|
@ -124,7 +124,8 @@
|
|||
" 'wandb',\n",
|
||||
" 'colorama',\n",
|
||||
" 'keyboard',\n",
|
||||
" 'lion-pytorch'\n",
|
||||
" 'lion-pytorch',\n",
|
||||
" 'safetensors'\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"print(colored(0, 255, 0, 'Installing packages...'))\n",
|
||||
|
@ -543,7 +544,7 @@
|
|||
" --resolution $Resolution \\\n",
|
||||
" --sample_prompts \"$Sample_File\" \\\n",
|
||||
" --sample_steps $Steps_between_samples \\\n",
|
||||
" --save_every_n_epoch $Save_every_N_epoch \\\n",
|
||||
" --save_every_n_epochs $Save_every_N_epoch \\\n",
|
||||
" --seed $Training_Seed \\\n",
|
||||
" --zero_frequency_noise_ratio $zero_frequency_noise\n",
|
||||
"\n",
|
||||
|
|
|
@ -41,6 +41,7 @@ class EveryDreamBatch(Dataset):
|
|||
seed=555,
|
||||
tokenizer=None,
|
||||
shuffle_tags=False,
|
||||
keep_tags=0,
|
||||
rated_dataset=False,
|
||||
rated_dataset_dropout_target=0.5,
|
||||
name='train'
|
||||
|
@ -54,6 +55,7 @@ class EveryDreamBatch(Dataset):
|
|||
self.tokenizer = tokenizer
|
||||
self.max_token_length = self.tokenizer.model_max_length
|
||||
self.shuffle_tags = shuffle_tags
|
||||
self.keep_tags = keep_tags
|
||||
self.seed = seed
|
||||
self.rated_dataset = rated_dataset
|
||||
self.rated_dataset_dropout_target = rated_dataset_dropout_target
|
||||
|
@ -94,7 +96,7 @@ class EveryDreamBatch(Dataset):
|
|||
)
|
||||
|
||||
if self.shuffle_tags or train_item["shuffle_tags"]:
|
||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed)
|
||||
example["caption"] = train_item["caption"].get_shuffled_caption(self.seed, keep_tags=self.keep_tags)
|
||||
else:
|
||||
example["caption"] = train_item["caption"].get_caption()
|
||||
|
||||
|
@ -138,7 +140,59 @@ class EveryDreamBatch(Dataset):
|
|||
def __update_image_train_items(self, dropout_fraction: float):
|
||||
self.image_train_items = self.data_loader.get_shuffled_image_buckets(dropout_fraction)
|
||||
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
class DataLoaderWithFixedBuffer(torch.utils.data.DataLoader):
|
||||
def __init__(self, dataset, buffer_tensor, batch_size:int, max_pixels: int, buffer_dtype: torch.dtype, device="cuda"):
|
||||
color_channels = 3
|
||||
buffer_size = batch_size * color_channels * max_pixels
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
buffer_tensor = torch.empty(buffer_size, dtype=buffer_dtype, device=device).pin_memory()
|
||||
self.buffer_tensor = buffer_tensor
|
||||
logging.info(f"buffer_tensor created with shape: {buffer_tensor.shape}")
|
||||
|
||||
super().__init__(dataset, batch_size=batch_size, shuffle=False, num_workers=min(batch_size, os.cpu_count()), collate_fn=self.fixed_collate_fn)
|
||||
|
||||
def fixed_collate_fn(self, batch):
|
||||
"""
|
||||
Collates images to a pinned buffer returned as a view using actual resolution shape
|
||||
"""
|
||||
images = [example["image"] for example in batch]
|
||||
|
||||
# map the image data to the fixed buffer view
|
||||
w, h = images[0].size
|
||||
for i in range(self.batch_size):
|
||||
image = batch["image"][i]
|
||||
self.buffer_tensor[i*self.buffer_size//self.batch_size:(i+1)*self.buffer_size//self.batch_size] = image.view(-1)
|
||||
images = self.buffer_tensor.view(self.batch_size, 3, w, h)
|
||||
|
||||
captions = [example["caption"] for example in batch]
|
||||
tokens = [example["tokens"] for example in batch]
|
||||
runt_size = batch[0]["runt_size"]
|
||||
|
||||
images = torch.stack(images)
|
||||
images = images.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
ret = {
|
||||
"tokens": torch.stack(tuple(tokens)),
|
||||
"image": images,
|
||||
"captions": captions,
|
||||
"runt_size": runt_size,
|
||||
}
|
||||
del batch
|
||||
return ret
|
||||
|
||||
def build_torch_dataloader2(dataset, batch_size, max_pixels) -> torch.utils.data.DataLoader:
|
||||
dataloader = DataLoaderWithFixedBuffer(
|
||||
dataset,
|
||||
max_pixels=max_pixels,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=min(batch_size, os.cpu_count()),
|
||||
collate_fn=collate_fn
|
||||
)
|
||||
return dataloader
|
||||
|
||||
def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size= batch_size,
|
||||
|
@ -148,7 +202,6 @@ def build_torch_dataloader(dataset, batch_size) -> torch.utils.data.DataLoader:
|
|||
)
|
||||
return dataloader
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collates batches
|
||||
|
|
|
@ -56,7 +56,7 @@ class ImageCaption:
|
|||
def rating(self) -> float:
|
||||
return self.__rating
|
||||
|
||||
def get_shuffled_caption(self, seed: int) -> str:
|
||||
def get_shuffled_caption(self, seed: int, keep_tags: int) -> str:
|
||||
"""
|
||||
returns the caption a string with a random selection of the tags in random order
|
||||
:param seed used to initialize the randomizer
|
||||
|
@ -74,7 +74,7 @@ class ImageCaption:
|
|||
if self.__use_weights:
|
||||
tags_caption = self.__get_weighted_shuffled_tags(seed, self.__tags, self.__tag_weights, max_target_tag_length)
|
||||
else:
|
||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags)
|
||||
tags_caption = self.__get_shuffled_tags(seed, self.__tags, keep_tags)
|
||||
|
||||
return self.__main_prompt + ", " + tags_caption
|
||||
return self.__main_prompt
|
||||
|
@ -111,8 +111,16 @@ class ImageCaption:
|
|||
return caption
|
||||
|
||||
@staticmethod
|
||||
def __get_shuffled_tags(seed: int, tags: list[str]) -> str:
|
||||
random.Random(seed).shuffle(tags)
|
||||
def __get_shuffled_tags(seed: int, tags: list[str], keep_tags: int) -> str:
|
||||
tags = tags.copy()
|
||||
keep_tags = min(keep_tags, 0)
|
||||
|
||||
if len(tags) > keep_tags:
|
||||
fixed_tags = tags[:keep_tags]
|
||||
rest = tags[keep_tags:]
|
||||
random.Random(seed).shuffle(rest)
|
||||
tags = fixed_tags + rest
|
||||
|
||||
return ", ".join(tags)
|
||||
|
||||
class ImageTrainItem:
|
||||
|
@ -306,8 +314,10 @@ class ImageTrainItem:
|
|||
image_aspect = width / height
|
||||
target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect))
|
||||
|
||||
self.is_undersized = (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)
|
||||
self.is_undersized = (width != target_wh[0] and height != target_wh[1]) and (width * height) < (target_wh[0]*1.02 * target_wh[1]*1.02)
|
||||
|
||||
self.target_wh = target_wh
|
||||
self.image_size = image.size
|
||||
except Exception as e:
|
||||
self.error = e
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ Start with the [Low VRAM guide](TWEAKING.md) if you are having trouble training
|
|||
|
||||
## Resolution
|
||||
|
||||
You can train resolutions from 512 to 1024 in 64 pixel increments. General results from the community indicate you can push the base model a bit beyond what it was designed for *with enough training*. This will work out better when you have a lot of training data (hundreds+) and enable slightly higher resolution at inference time without seeing repeats in your generated images. This does cost speed of training and higher VRAM use! Ex. 768 takes a significant amount more VRAM than 512, so you will need to compensate for that by reducing ```batch_size```.
|
||||
You can train resolutions from 512 to 1024 in 64 pixel increments. General results from the community indicate you can push the base model a bit beyond what it was designed for *with enough training*. This will work out better when you have a lot of training data (hundreds+) and enable slightly higher resolution at inference time without seeing repeats in your generated images. This does cost speed of training and higher VRAM use! Ex. 768 takes a significant amount of additional VRAM than 512, so you will need to compensate for that by reducing ```batch_size```.
|
||||
|
||||
--resolution 640 ^
|
||||
|
||||
|
@ -14,21 +14,21 @@ For instance, if training from the base 1.5 model, you can try trying at 576, 64
|
|||
|
||||
If you are training on a base model that is 768, such as "SD 2.1 768-v", you should also probably use 768 as a base number and adjust from there.
|
||||
|
||||
Some results from the community seem to indicate training at a higher resolution on SD1.x models may increase how fast the model learns, and it may be a good idea to slightly reduce your learning rate as you increase resolution. My suspcision is that the higher resolutions increase the gradients as more information is presented to the model per image.
|
||||
Some results from the community seem to indicate training at a higher resolution on SD1.x models may increase how fast the model learns, and it may be a good idea to slightly reduce your learning rate as you increase resolution. My suspicion is that the higher resolutions increase the gradients as more information is presented to the model per image.
|
||||
|
||||
You may need to experiment with LR as you increase resolution. I don't have a perfect rule of thumb here, but I might suggest if you train SD1.5 which is a 512 model at resolution 768 you reduce your LR by about half. ED2 tends to prefer ~2e-6 to ~5e-6 for normal 512 training on SD1.X models around batch 6-8, so if you train SD1.X at 768 consider 1e-6 to 2.5e-6 instead.
|
||||
You may need to experiment with the LR as you increase resolution. I don't have a perfect rule of thumb here, but I might suggest if you train SD1.5 which is a 512 model at resolution 768 you reduce your LR by about half. ED2 tends to prefer ~2e-6 to ~5e-6 for normal 512 training on SD1.X models around batch 6-8, so if you train SD1.X at 768 consider 1e-6 to 2.5e-6 instead.
|
||||
|
||||
## Log and ckpt save folders
|
||||
|
||||
If you want to use a nondefault location for saving logs or ckpt files, these:
|
||||
|
||||
Logdir defaults to the "logs" folder in the trainer directory. If you wan to save all logs (including diffuser copies of ckpts, sample images, and tensbooard events) use this:
|
||||
Logdir defaults to the "logs" folder in the trainer directory. If you want to save all logs (including diffuser copies of ckpts, sample images, and tensbooard events) use this:
|
||||
|
||||
--logdir "/workspace/mylogs"
|
||||
|
||||
Remember to use the same folder when you launch tensorboard (```tensorboard --logdir "/worksapce/mylogs"```) or it won't find your logs.
|
||||
|
||||
By default the CKPT format copies of ckpts that are peroidically saved are saved in the trainer root folder. If you want to save them elsewhere, use this:
|
||||
By default the CKPT format copies of ckpts that are periodically saved are saved in the trainer root folder. If you want to save them elsewhere, use this:
|
||||
|
||||
--save_ckpt_dir "r:\webui\models\stable-diffusion"
|
||||
|
||||
|
@ -125,11 +125,11 @@ Seed can be used to make training either more or less deterministic. The seed v
|
|||
|
||||
To use a random seed, use -1:
|
||||
|
||||
-- seed -1
|
||||
--seed -1
|
||||
|
||||
Default behavior is to use a fixed seed of 555. The seed you set is fixed for all samples if you set a value other than -1. If you set a seed it is also incrememted for shuffling your training data every epoch (i.e. 555, 556, 557, etc). This makes training more deterministic. I suggest a fixed seed when you are trying A/B test tweaks to your general training setup, or when you want all your test samples to use the same seed.
|
||||
|
||||
Fixed seed should be using when performing A/B tests or hyperparameter sweeps. Random seed (-1) may be better if you are stopping and resuming training often so every restart is using random values for all of the various randomness sources used in training such as noising and data shuffling.
|
||||
Fixed seed should be used when performing A/B tests or hyperparameter sweeps. Random seed (-1) may be better if you are stopping and resuming training often so every restart is using random values for all of the various randomness sources used in training such as noising and data shuffling.
|
||||
|
||||
## Shuffle tags
|
||||
|
||||
|
@ -139,6 +139,12 @@ For those training booru tagged models, you can use this arg to randomly (but de
|
|||
|
||||
This simply chops the captions in to parts based on the commas and shuffles the order.
|
||||
|
||||
In case you want to keep static the first N tags, you can also add this parameter (`--shuffle_tags` must also be set):
|
||||
|
||||
--keep_tags 4 ^
|
||||
|
||||
The above example will keep static the 4 first additional tags, and shuffle the rest.
|
||||
|
||||
## Zero frequency noise
|
||||
|
||||
Based on [Nicholas Guttenberg's blog post](https://www.crosslabs.org//blog/diffusion-with-offset-noise) zero frequency noise offsets the noise added to the image during training/denoising, which can help improve contrast and the ability to render very dark or very bright scenes more accurately, and may help slightly with color saturation.
|
||||
|
@ -149,7 +155,7 @@ Based on [Nicholas Guttenberg's blog post](https://www.crosslabs.org//blog/diffu
|
|||
|
||||
Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/zero_freq_test_biggs.webp
|
||||
|
||||
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergance over time.
|
||||
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergence over time.
|
||||
|
||||
## Zero terminal SNR
|
||||
|
||||
|
@ -205,3 +211,67 @@ Clips the gradient normals to a maximum value. Default is None (no clipping).
|
|||
|
||||
Default is no gradient normal clipping. There are also other ways to deal with gradient explosion, such as increasing optimizer epsilon.
|
||||
|
||||
## Zero Terminal SNR
|
||||
**Parameter:** `--enable_zero_terminal_snr`
|
||||
**Default:** `False`
|
||||
To enable zero terminal SNR.
|
||||
|
||||
## Dynamic Configuration Loading
|
||||
**Parameter:** `--load_settings_every_epoch`
|
||||
**Default:** `False`
|
||||
Most of the parameters in the train.json file CANNOT be modified during training. Activate this to have the `train.json` configuration file reloaded at the start of each epoch. The following parameter can be changed and will be applied after the start of a new epoch:
|
||||
- `--save_every_n_epochs`
|
||||
- `--save_ckpts_from_n_epochs`
|
||||
- `--save_full_precision`
|
||||
- `--save_optimizer`
|
||||
- `--zero_frequency_noise_ratio`
|
||||
- `--min_snr_gamma`
|
||||
- `--clip_skip`
|
||||
|
||||
## Min-SNR-Gamma Parameter
|
||||
**Parameter:** `--min_snr_gamma`
|
||||
**Recommended Values:** 5, 1, 20
|
||||
**Default:** `None`
|
||||
To enable min-SNR-Gamma. For an in-depth understanding, consult this [research paper](https://arxiv.org/abs/2303.09556).
|
||||
|
||||
## EMA Decay Features
|
||||
The Exponential Moving Average (EMA) model is copied from the base model at the start and is updated every interval of steps by a small contribution from training.
|
||||
In this mode, the EMA model will be saved alongside the regular checkpoint from training. Normal training checkpoint can be loaded with `--resume_ckpt`, and the EMA model can be loaded with `--ema_decay_resume_model`.
|
||||
For more information, consult the [research paper](https://arxiv.org/abs/2101.08482) or continue reading the tuning notes below.
|
||||
**Parameters:**
|
||||
- `--ema_decay_rate`: Determines the EMA decay rate. It defines how much the EMA model is updated from training at each update. Values should be close to 1 but not exceed it. Activating this parameter triggers the EMA decay feature.
|
||||
- `--ema_strength_target`: Set the EMA strength target value within the (0,1) range. The `ema_decay_rate` is computed based on the relation: decay_rate to the power of (total_steps/decay_interval) equals decay_target. Enabling this parameter will override `ema_decay_rate` and will enable EMA feature. See [ema_strength_target](#ema_strength_target) for more information.
|
||||
- `--ema_update_interval`: Set the interval in steps between EMA updates. The update occurs at each optimizer step. If you use grad_accum, actual update interval will be multipled by your grad_accum value.
|
||||
- `--ema_device`: Choose between `cpu` and `cuda` for EMA. Opting for 'cpu' takes around 4 seconds per update and uses approximately 3.2GB RAM, while 'cuda' is much faster but requires a similar amount of VRAM.
|
||||
- `--ema_sample_nonema_model`: Activate to display samples from the non-ema trained model, mirroring conventional training. They will not be presented by default with EMA decay enabled.
|
||||
- `--ema_sample_ema_model`: Turn on to exhibit samples from the EMA model. EMA models will be used for samples generations by default with EMA decay enabled, unless disabled.
|
||||
- `--ema_resume_model`: Indicate the EMA decay checkpoint to continue from, working like `--resume_ckpt` but will load EMA model. Using `findlast` will only load EMA version and not regular training.
|
||||
|
||||
## Notes on tuning EMA.
|
||||
|
||||
The purpose of EMA is to reduce the effect of the data from the tail end of training from having an overly powerful effect on the model. Normally trainig is stopped abruptly and the final images seen by the trainer may have a stronger effect than images seen earlier in training. *This may have a similar to lowering the learning rate near the end of training, but is not mathematically equivalent.* An alternative method to EMA would be to use a cosine learning rate schedule.
|
||||
|
||||
Training with EMA turned on has no effect on the non-EMA model if all other settings are identical, though practical considerations (mainly VRAM limits) may cause you to change other settings which can affect the non-EMA model, such as lowering batch size to free enough VRAM for the EMA model if using gpu.
|
||||
|
||||
A standard implementation of EMA uses a decay rate of 0.9999, GPU device, and an interval of 1 (every optimizer step). This value can have a strong effect, leading to what appears to be an undertrained EMA model compared to the non-EMA model. A value of 0.999 seems to produce an EMA model nearly identical to the non-EMA model and should be considered a low value. Somewhere in the 0.999-0.9999 range is suggested when using GPU and interval 1.
|
||||
|
||||
EMA uses an additional ~3.2GB of RAM (for SD1.x models) to store an extra copy of the model weights in memory. For even 24GB consumer GPUs this is substantial, but EMA CPU offloading together with using a higher `ema_update_interval` can make it more practical. It can be practical on a 24GB GPU if you also enable gradient checkpointing, which is not normally suggested for 24GB GPUs as it is not necessary. Gradient checkpointing saves a bit more than 3.2GB itself. The other options is to use CPU offloading by setting `ema_device: "cpu"`. The EMA copy of the model will be stored in main system memory instead of the GPU, but at a cost of slower sampling and updating. CPU offloading is a requirement for GPUs with 16GB or less VRAM, and even 24GB GPU users may wish to consider it. If you are using a 40GB+ GPU you should use GPU.
|
||||
|
||||
When using a higher interval to make cpu offloading practical and reasonably fast, the decay rate should be lowered. For instance, with an interval of 50, you may wish to lower the decay rate to 0.99 or possibly lower. This is because the EMA model is updated less frequently and the decay rate is effectively higher than the set value under otherwise "normal" EMA training regime. The higher interval also reduces accuracy of the EMA model compared to the reference implementation which would normally update EMA every optimizer step.
|
||||
|
||||
I would suggest you pick an interval and stick with it, and then tune your decay_rate by generating samples from both EMA and non EMA using the options or after training using your favorite inference app and compare the results.
|
||||
It is expected the EMA model will look "behind" on training, but should still be recognizable as the same subject matter. If it is not, you may wish to try a lower decay rate. If it is too close to the non-EMA model, you may wish to try a higher decay rate.
|
||||
|
||||
Using the GPU for ema incurs only a small speed penalty of around 5-10% with all else being equal, though if you change other parameters such as lowering batch size or enabling gradient checkpointing flag to free VRAM for EMA those options may incur a slightly higher speed penalities.
|
||||
|
||||
Generally, I recommend picking a device and approriate interval given your device choice first and stick with those values, then tweak the `ema_decay_rate` up or down according to how you want the EMA model to look vs. your non-EMA model. From there, if your EMA model seems to "lag behind" the non-EMA model by "too much" (subjectively judged), you can decrease decay rate. If it identical or nearly identical, use a slightly higher value.
|
||||
|
||||
## ema_strength_target
|
||||
|
||||
This arg is a non-standard way of calculating the actual decay rate used. It attempts to calculate a value for decay rate based on your `ema_update_interval` and the total length of training, compensating for both. Values of 0.01-0.15 should work, with higher values leading to a EMA model that deviates more from the non-EMA model similar to how decay rate works. It attempts to be more of a "strength" value of EMA, or "how much" (as a factor, i.e. 0.10 = 10% "strength") of the EMA model are kept for the totality of training.
|
||||
|
||||
While the calculation makes sense in how it compensates for inteval and total training length, it is not a standard way of calculating decay rate and there will not be information online about how to use it. I recommend not using this feature and instead picking a device and approriate interval given your device choice first, then tuning your decay rate by hand, find "good" values, then don't mess with them, but you can try this feature out if you want.
|
||||
|
||||
--ema_strength_target 0.10 ^
|
||||
|
||||
If you use `ema_strength_target` the actual calculated `ema_decay_rate` used will be printed in your logs, and you should pay attention to this value and use it to inform your future decisions on EMA tuning.
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
`python caption_fl.py --data_root input --min_new_tokens 20 --max_new_tokens 30 --num_beams 3 --model "openflamingo/OpenFlamingo-9B-vitl-mpt7b"`
|
||||
|
||||
This script uses two example image/caption pairs located in the `/example` folder to prime the system to caption, then captions the images in the input folder. It will save a `.txt` file of the same base filename with the captoin in the same folder.
|
||||
This script uses two example image/caption pairs located in the `/example` folder to prime the system to caption, then captions the images in the input folder. It will save a `.txt` file of the same base filename with the caption in the same folder.
|
||||
|
||||
This script currently requires an AMPERE or newer GPU due to using bfloat16.
|
||||
|
||||
|
|
|
@ -29,12 +29,13 @@ For each of the `unet` and `text_encoder` sections, you can set the following pr
|
|||
Standard full precision AdamW optimizer exposed by PyTorch. Not recommended. Slower and uses more memory than adamw8bit. Widely documented on the web.
|
||||
|
||||
* adamw8bit
|
||||
* lion8bit
|
||||
|
||||
Tim Dettmers / bitsandbytes AdamW 8bit optimizer. This is the default and recommended setting. Widely documented on the web.
|
||||
Tim Dettmers / bitsandbytes AdamW and Lion 8bit optimizer. adamw8bit is the default and recommended setting as it is well understood, and lion8bit is very vram efficient. Widely documented on the web.
|
||||
|
||||
* lion
|
||||
|
||||
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click links to read more. `Epsilon` is not used by lion.
|
||||
Lucidrains' [implementation](https://github.com/lucidrains/lion-pytorch) of the [lion optimizer](https://arxiv.org/abs/2302.06675). Click links to read more. `Epsilon` is not used by lion. You should prefer lion8bit over this optimizer as it is more memory efficient.
|
||||
|
||||
Recommended settings for lion based on the paper are as follows:
|
||||
|
||||
|
@ -61,7 +62,13 @@ Available optimizer values for Dadaptation are:
|
|||
|
||||
* dadapt_lion, dadapt_adam, dadapt_sgd
|
||||
|
||||
These are fairly experimental but tested as working. Gradient checkpointing may be required even on 24GB GPUs. Performance is slower than the compiled and optimized AdamW8bit optimizer unless you increae gradient accumulation as it seems the accumulation steps process slowly with the current implementation of D-Adaption
|
||||
These are fairly experimental but tested as working. Gradient checkpointing may be required even on 24GB GPUs. Performance is slower than the compiled and optimized AdamW8bit optimizer unless you increae gradient accumulation as it seems the accumulation steps process slowly with the current implementation of D-Adaption.
|
||||
|
||||
#### Prodigy
|
||||
|
||||
Another adaptive optimizer. It is not very VRAM efficient. [Github](https://github.com/konstmish/prodigy), [Paper](https://arxiv.org/pdf/2306.06101.pdf)
|
||||
|
||||
* prodigy
|
||||
|
||||
## Optimizer parameters
|
||||
|
||||
|
|
|
@ -65,3 +65,4 @@ The effect of the limit is that the caption will always be truncated when the ma
|
|||
exceeded. This process does not consider if the cutoff is in the middle of a tag or even in the middle of a
|
||||
word if it is translated into several tokens.
|
||||
|
||||
To mitigate this token limitation (when not using weighted shuffling), the `--keep_tags n` parameter can be employed. This ensures that the first n tags following the initial chunk remain static, while the remaining tags are shuffled.
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
aiohttp==3.8.4
|
||||
bitsandbytes==0.38.1
|
||||
bitsandbytes==0.41.1
|
||||
scipy
|
||||
colorama==0.4.6
|
||||
compel~=1.1.3
|
||||
ftfy==6.1.1
|
||||
|
@ -14,4 +15,5 @@ pynvml==11.5.0
|
|||
speedtest-cli
|
||||
tensorboard==2.12.0
|
||||
wandb
|
||||
safetensors
|
||||
safetensors
|
||||
prodigyopt
|
||||
|
|
|
@ -276,6 +276,7 @@ class EveryDreamOptimizer():
|
|||
decouple = True # seems bad to turn off, dadapt_adam only
|
||||
momentum = 0.0 # dadapt_sgd
|
||||
no_prox = False # ????, dadapt_adan
|
||||
use_bias_correction = True # suggest by prodigy github
|
||||
growth_rate=float("inf") # dadapt various, no idea what a sane default is
|
||||
|
||||
if local_optimizer_config is not None:
|
||||
|
@ -307,6 +308,30 @@ class EveryDreamOptimizer():
|
|||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
)
|
||||
elif optimizer_name == "lion8bit":
|
||||
from bitsandbytes.optim import Lion8bit
|
||||
opt_class = Lion8bit
|
||||
optimizer = opt_class(
|
||||
itertools.chain(parameters),
|
||||
lr=curr_lr,
|
||||
betas=(betas[0], betas[1]),
|
||||
weight_decay=weight_decay,
|
||||
percentile_clipping=100,
|
||||
min_8bit_size=4096,
|
||||
)
|
||||
elif optimizer_name == "prodigy":
|
||||
from prodigyopt import Prodigy
|
||||
opt_class = Prodigy
|
||||
safeguard_warmup = True # per recommendation from prodigy documentation
|
||||
optimizer = opt_class(
|
||||
itertools.chain(parameters),
|
||||
lr=curr_lr,
|
||||
weight_decay=weight_decay,
|
||||
use_bias_correction=use_bias_correction,
|
||||
growth_rate=growth_rate,
|
||||
d0=d0,
|
||||
safeguard_warmup=safeguard_warmup
|
||||
)
|
||||
elif optimizer_name == "adamw":
|
||||
opt_class = torch.optim.AdamW
|
||||
if "dowg" in optimizer_name:
|
||||
|
@ -317,7 +342,7 @@ class EveryDreamOptimizer():
|
|||
elif optimizer_name == "scalar_dowg":
|
||||
opt_class = dowg.ScalarDoWG
|
||||
else:
|
||||
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are coordinate_dowg and scalar_dowg")
|
||||
raise ValueError(f"Unknown DoWG optimizer {optimizer_name}. Available options are 'coordinate_dowg' and 'scalar_dowg'")
|
||||
elif optimizer_name in ["dadapt_adam", "dadapt_lion", "dadapt_sgd"]:
|
||||
import dadaptation
|
||||
|
||||
|
|
|
@ -0,0 +1,56 @@
|
|||
import math
|
||||
import os
|
||||
import shutil
|
||||
from plugins.plugins import BasePlugin
|
||||
from train import save_model
|
||||
|
||||
EVERY_N_EPOCHS = 1 # how often to save. integers >= 1 save at the end of every nth epoch. floats < 1 subdivide the epoch evenly (eg 0.33 = 3 subdivisions)
|
||||
|
||||
class InterruptiblePlugin(BasePlugin):
|
||||
|
||||
def __init__(self):
|
||||
print("Interruptible plugin instantiated")
|
||||
self.previous_save_path = None
|
||||
self.every_n_epochs = EVERY_N_EPOCHS
|
||||
|
||||
def on_epoch_start(self, **kwargs):
|
||||
epoch = kwargs['epoch']
|
||||
epoch_length = kwargs['epoch_length']
|
||||
self.steps_to_save_this_epoch = self._get_save_step_indices(epoch, epoch_length)
|
||||
|
||||
def on_step_end(self, **kwargs):
|
||||
local_step = kwargs['local_step']
|
||||
if local_step in self.steps_to_save_this_epoch:
|
||||
global_step = kwargs['global_step']
|
||||
epoch = kwargs['epoch']
|
||||
project_name = kwargs['project_name']
|
||||
log_folder = kwargs['log_folder']
|
||||
ckpt_name = f"rolling-{project_name}-ep{epoch:02}-gs{global_step:05}"
|
||||
save_path = os.path.join(log_folder, "ckpts", ckpt_name)
|
||||
print(f"{type(self)} saving model to {save_path}")
|
||||
save_model(save_path, global_step=global_step, ed_state=kwargs['ed_state'], save_ckpt_dir=None, yaml_name=None, save_ckpt=False, save_full_precision=True, save_optimizer_flag=True)
|
||||
self._remove_previous()
|
||||
self.previous_save_path = save_path
|
||||
|
||||
def on_training_end(self, **kwargs):
|
||||
self._remove_previous()
|
||||
|
||||
def _remove_previous(self):
|
||||
if self.previous_save_path is not None:
|
||||
shutil.rmtree(self.previous_save_path, ignore_errors=True)
|
||||
self.previous_save_path = None
|
||||
|
||||
def _get_save_step_indices(self, epoch, epoch_length_steps: int) -> list[int]:
|
||||
if self.every_n_epochs >= 1:
|
||||
if ((epoch+1) % self.every_n_epochs) == 0:
|
||||
# last step only
|
||||
return [epoch_length_steps-1]
|
||||
else:
|
||||
return []
|
||||
else:
|
||||
# subdivide the epoch evenly, by rounding self.every_n_epochs to the nearest clean division of steps
|
||||
num_divisions = max(1, min(epoch_length_steps, round(1/self.every_n_epochs)))
|
||||
# validation happens after training:
|
||||
# if an epoch has eg 100 steps and num_divisions is 2, then validation should occur after steps 49 and 99
|
||||
validate_every_n_steps = epoch_length_steps / num_divisions
|
||||
return [math.ceil((i+1)*validate_every_n_steps) - 1 for i in range(num_divisions)]
|
|
@ -44,7 +44,7 @@ class Timer:
|
|||
def __exit__(self, type, value, traceback):
|
||||
elapsed_time = time.time() - self.start
|
||||
if elapsed_time > self.warn_seconds:
|
||||
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.limit} seconds')
|
||||
logging.warning(f'Execution of {self.label} took {elapsed_time} seconds which is longer than the limit of {self.warn_seconds} seconds')
|
||||
|
||||
|
||||
class PluginRunner:
|
||||
|
|
|
@ -3,7 +3,7 @@ torchvision==0.15.2
|
|||
transformers==4.29.2
|
||||
diffusers[torch]==0.18.0
|
||||
pynvml==11.4.1
|
||||
bitsandbytes==0.38.1
|
||||
bitsandbytes==0.41.1
|
||||
ftfy==6.1.1
|
||||
aiohttp==3.8.4
|
||||
tensorboard>=2.11.0
|
||||
|
@ -21,4 +21,4 @@ numpy==1.23.5
|
|||
wandb
|
||||
colorama
|
||||
safetensors
|
||||
open-flamingo==2.0.0
|
||||
open-flamingo==2.0.0
|
||||
|
|
|
@ -4,6 +4,7 @@ import pathlib
|
|||
import PIL.Image as Image
|
||||
|
||||
from data.image_train_item import ImageCaption, ImageTrainItem
|
||||
import data.aspects as aspects
|
||||
|
||||
DATA_PATH = pathlib.Path('./test/data')
|
||||
|
||||
|
@ -32,4 +33,70 @@ class TestImageCaption(unittest.TestCase):
|
|||
self.assertEqual(caption.get_caption(), "hello world, one, two, three")
|
||||
|
||||
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
||||
self.assertEqual(caption.get_caption(), "hello world")
|
||||
|
||||
class TestImageTrainItemConstructor(unittest.TestCase):
|
||||
|
||||
def tearDown(self) -> None:
|
||||
for file in DATA_PATH.glob("img_*"):
|
||||
file.unlink()
|
||||
|
||||
return super().tearDown()
|
||||
|
||||
@staticmethod
|
||||
def image_with_size(width, height):
|
||||
filename = DATA_PATH / "img_{}x{}.jpg".format(width, height)
|
||||
Image.new("RGB", (width, height)).save(filename)
|
||||
caption = ImageCaption("hello world", 1.0, [], [], 2048, False)
|
||||
return ImageTrainItem(None, caption, aspects.ASPECTS_512, filename, 0.0, 1.0, False, False, 0)
|
||||
|
||||
def test_target_size_computation(self):
|
||||
# Square images
|
||||
image = self.image_with_size(30, 30)
|
||||
self.assertEqual(image.target_wh, [512,512])
|
||||
self.assertTrue(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (30,30))
|
||||
|
||||
image = self.image_with_size(512, 512)
|
||||
self.assertEqual(image.target_wh, [512,512])
|
||||
self.assertFalse(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (512,512))
|
||||
|
||||
image = self.image_with_size(580, 580)
|
||||
self.assertEqual(image.target_wh, [512,512])
|
||||
self.assertFalse(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (580,580))
|
||||
|
||||
# Landscape images
|
||||
image = self.image_with_size(64, 38)
|
||||
self.assertEqual(image.target_wh, [640,384])
|
||||
self.assertTrue(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (64,38))
|
||||
|
||||
image = self.image_with_size(640, 384)
|
||||
self.assertEqual(image.target_wh, [640,384])
|
||||
self.assertFalse(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (640,384))
|
||||
|
||||
image = self.image_with_size(704, 422)
|
||||
self.assertEqual(image.target_wh, [640,384])
|
||||
self.assertFalse(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (704,422))
|
||||
|
||||
# Portrait images
|
||||
image = self.image_with_size(38, 64)
|
||||
self.assertEqual(image.target_wh, [384,640])
|
||||
self.assertTrue(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (38,64))
|
||||
|
||||
image = self.image_with_size(384, 640)
|
||||
self.assertEqual(image.target_wh, [384,640])
|
||||
self.assertFalse(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (384,640))
|
||||
|
||||
image = self.image_with_size(422, 704)
|
||||
self.assertEqual(image.target_wh, [384,640])
|
||||
self.assertFalse(image.is_undersized)
|
||||
self.assertEqual(image.image_size, (422,704))
|
||||
|
||||
|
||||
|
|
18
train.json
18
train.json
|
@ -4,7 +4,7 @@
|
|||
"clip_grad_norm": null,
|
||||
"clip_skip": 0,
|
||||
"cond_dropout": 0.04,
|
||||
"data_root": "X:\\my_project_data\\project_abc",
|
||||
"data_root": "/mnt/q/training_samples/ff7r/man",
|
||||
"disable_amp": false,
|
||||
"disable_textenc_training": false,
|
||||
"disable_xformers": false,
|
||||
|
@ -19,12 +19,12 @@
|
|||
"lr_decay_steps": 0,
|
||||
"lr_scheduler": "constant",
|
||||
"lr_warmup_steps": null,
|
||||
"max_epochs": 30,
|
||||
"max_epochs": 1,
|
||||
"notebook": false,
|
||||
"optimizer_config": "optimizer.json",
|
||||
"project_name": "project_abc",
|
||||
"resolution": 512,
|
||||
"resume_ckpt": "sd_v1-5_vae",
|
||||
"resume_ckpt": "panopstor/EveryDream",
|
||||
"run_name": null,
|
||||
"sample_prompts": "sample_prompts.txt",
|
||||
"sample_steps": 300,
|
||||
|
@ -40,5 +40,15 @@
|
|||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
"rated_dataset_target_dropout_percent": 50,
|
||||
"zero_frequency_noise_ratio": 0.02
|
||||
"zero_frequency_noise_ratio": 0.02,
|
||||
"enable_zero_terminal_snr": false,
|
||||
"load_settings_every_epoch": false,
|
||||
"min_snr_gamma": null,
|
||||
"ema_decay_rate": null,
|
||||
"ema_strength_target": null,
|
||||
"ema_update_interval": null,
|
||||
"ema_device": null,
|
||||
"ema_sample_nonema_model": false,
|
||||
"ema_sample_ema_model": false,
|
||||
"ema_resume_model" : null
|
||||
}
|
||||
|
|
579
train.py
579
train.py
|
@ -27,6 +27,7 @@ import gc
|
|||
import random
|
||||
import traceback
|
||||
import shutil
|
||||
from typing import Optional
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch.cuda.amp import autocast
|
||||
|
@ -61,6 +62,7 @@ from utils.convert_diff_to_ckpt import convert as converter
|
|||
from utils.isolate_rng import isolate_rng
|
||||
from utils.check_git import check_git
|
||||
from optimizer.optimizers import EveryDreamOptimizer
|
||||
from copy import deepcopy
|
||||
|
||||
if torch.cuda.is_available():
|
||||
from utils.gpu import GPU
|
||||
|
@ -101,6 +103,108 @@ def convert_to_hf(ckpt_path):
|
|||
is_sd1attn, yaml = get_attn_yaml(ckpt_path)
|
||||
return ckpt_path, is_sd1attn, yaml
|
||||
|
||||
class EveryDreamTrainingState:
|
||||
def __init__(self,
|
||||
optimizer: EveryDreamOptimizer,
|
||||
train_batch: EveryDreamBatch,
|
||||
unet: UNet2DConditionModel,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
scheduler,
|
||||
vae: AutoencoderKL,
|
||||
unet_ema: Optional[UNet2DConditionModel],
|
||||
text_encoder_ema: Optional[CLIPTextModel]
|
||||
):
|
||||
self.optimizer = optimizer
|
||||
self.train_batch = train_batch
|
||||
self.unet = unet
|
||||
self.text_encoder = text_encoder
|
||||
self.tokenizer = tokenizer
|
||||
self.scheduler = scheduler
|
||||
self.vae = vae
|
||||
self.unet_ema = unet_ema
|
||||
self.text_encoder_ema = text_encoder_ema
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def save_model(save_path, ed_state: EveryDreamTrainingState, global_step: int, save_ckpt_dir, yaml_name,
|
||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
|
||||
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
|
||||
nonlocal save_ckpt_dir
|
||||
nonlocal save_full_precision
|
||||
nonlocal yaml_name
|
||||
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
|
||||
half = not save_full_precision
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
|
||||
if global_step is None or global_step == 0:
|
||||
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
||||
return
|
||||
|
||||
if args.ema_decay_rate != None:
|
||||
pipeline_ema = StableDiffusionPipeline(
|
||||
vae=ed_state.vae,
|
||||
text_encoder=ed_state.text_encoder_ema,
|
||||
tokenizer=ed_state.tokenizer,
|
||||
unet=ed_state.unet_ema,
|
||||
scheduler=ed_state.scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
|
||||
diffusers_model_path = save_path + "_ema"
|
||||
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
|
||||
pipeline_ema.save_pretrained(diffusers_model_path)
|
||||
|
||||
if save_ckpt:
|
||||
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
|
||||
|
||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
|
||||
|
||||
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=ed_state.vae,
|
||||
text_encoder=ed_state.text_encoder,
|
||||
tokenizer=ed_state.tokenizer,
|
||||
unet=ed_state.unet,
|
||||
scheduler=ed_state.scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
|
||||
diffusers_model_path = save_path
|
||||
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
|
||||
pipeline.save_pretrained(diffusers_model_path)
|
||||
|
||||
if save_ckpt:
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
|
||||
|
||||
if save_optimizer_flag:
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
ed_state.optimizer.save(save_path)
|
||||
|
||||
|
||||
def setup_local_logger(args):
|
||||
"""
|
||||
configures logger with file and console logging, logs args, and returns the datestamp
|
||||
|
@ -186,7 +290,7 @@ def set_args_12gb(args):
|
|||
logging.info(" - Overiding resolution to max 512")
|
||||
args.resolution = 512
|
||||
|
||||
def find_last_checkpoint(logdir):
|
||||
def find_last_checkpoint(logdir, is_ema=False):
|
||||
"""
|
||||
Finds the last checkpoint in the logdir, recursively
|
||||
"""
|
||||
|
@ -196,6 +300,12 @@ def find_last_checkpoint(logdir):
|
|||
for root, dirs, files in os.walk(logdir):
|
||||
for file in files:
|
||||
if os.path.basename(file) == "model_index.json":
|
||||
|
||||
if is_ema and (not root.endswith("_ema")):
|
||||
continue
|
||||
elif (not is_ema) and root.endswith("_ema"):
|
||||
continue
|
||||
|
||||
curr_date = os.path.getmtime(os.path.join(root,file))
|
||||
|
||||
if last_date is None or curr_date > last_date:
|
||||
|
@ -228,12 +338,20 @@ def setup_args(args):
|
|||
# find the last checkpoint in the logdir
|
||||
args.resume_ckpt = find_last_checkpoint(args.logdir)
|
||||
|
||||
if (args.ema_resume_model != None) and (args.ema_resume_model == "findlast"):
|
||||
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
|
||||
|
||||
args.ema_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
|
||||
|
||||
if args.lowvram:
|
||||
set_args_12gb(args)
|
||||
|
||||
if not args.shuffle_tags:
|
||||
args.shuffle_tags = False
|
||||
|
||||
if not args.keep_tags:
|
||||
args.keep_tags = 0
|
||||
|
||||
args.clip_skip = max(min(4, args.clip_skip), 0)
|
||||
|
||||
if args.useadam8bit:
|
||||
|
@ -356,6 +474,74 @@ def log_args(log_writer, args):
|
|||
arglog += f"{arg}={value}, "
|
||||
log_writer.add_text("config", arglog)
|
||||
|
||||
def update_ema(model, ema_model, decay, default_device, ema_device):
|
||||
with torch.no_grad():
|
||||
original_model_on_proper_device = model
|
||||
need_to_delete_original = False
|
||||
if ema_device != default_device:
|
||||
original_model_on_other_device = deepcopy(model)
|
||||
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
|
||||
del original_model_on_other_device
|
||||
need_to_delete_original = True
|
||||
|
||||
params = dict(original_model_on_proper_device.named_parameters())
|
||||
ema_params = dict(ema_model.named_parameters())
|
||||
|
||||
for name in ema_params:
|
||||
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
|
||||
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
|
||||
|
||||
if need_to_delete_original:
|
||||
del(original_model_on_proper_device)
|
||||
|
||||
def compute_snr(timesteps, noise_scheduler):
|
||||
"""
|
||||
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
||||
"""
|
||||
minimal_value = 1e-9
|
||||
alphas_cumprod = noise_scheduler.alphas_cumprod
|
||||
# Use .any() to check if any elements in the tensor are zero
|
||||
if (alphas_cumprod[:-1] == 0).any():
|
||||
logging.warning(
|
||||
f"Alphas cumprod has zero elements! Resetting to {minimal_value}.."
|
||||
)
|
||||
alphas_cumprod[alphas_cumprod[:-1] == 0] = minimal_value
|
||||
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
||||
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
|
||||
# Expand the tensors.
|
||||
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
|
||||
timesteps
|
||||
].float()
|
||||
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
|
||||
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
|
||||
device=timesteps.device
|
||||
)[timesteps].float()
|
||||
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
|
||||
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
|
||||
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
|
||||
|
||||
# Compute SNR, first without epsilon
|
||||
snr = (alpha / sigma) ** 2
|
||||
# Check if the first element in SNR tensor is zero
|
||||
if torch.any(snr == 0):
|
||||
snr[snr == 0] = minimal_value
|
||||
return snr
|
||||
|
||||
def load_train_json_from_file(args, report_load = False):
|
||||
try:
|
||||
if report_load:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
|
||||
with open(args.config, 'rt') as f:
|
||||
read_json = json.load(f)
|
||||
|
||||
args.__dict__.update(read_json)
|
||||
except Exception as config_read:
|
||||
print(f"Error on loading training config from {args.config}.")
|
||||
|
||||
def main(args):
|
||||
"""
|
||||
|
@ -384,57 +570,28 @@ def main(args):
|
|||
device = 'cpu'
|
||||
gpu = None
|
||||
|
||||
|
||||
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
|
||||
|
||||
if not os.path.exists(log_folder):
|
||||
os.makedirs(log_folder)
|
||||
|
||||
@torch.no_grad()
|
||||
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
|
||||
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
|
||||
"""
|
||||
Save the model to disk
|
||||
"""
|
||||
global global_step
|
||||
if global_step is None or global_step == 0:
|
||||
logging.warning(" No model to save, something likely blew up on startup, not saving")
|
||||
return
|
||||
logging.info(f" * Saving diffusers model to {save_path}")
|
||||
pipeline = StableDiffusionPipeline(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # save vram
|
||||
requires_safety_checker=None, # avoid nag
|
||||
feature_extractor=None, # must be none of no safety checker
|
||||
)
|
||||
pipeline.save_pretrained(save_path)
|
||||
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
|
||||
def release_memory(model_to_delete, original_device):
|
||||
del model_to_delete
|
||||
gc.collect()
|
||||
|
||||
if save_ckpt:
|
||||
if save_ckpt_dir is not None:
|
||||
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
|
||||
else:
|
||||
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
|
||||
save_ckpt_dir = os.curdir
|
||||
if 'cuda' in original_device.type:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
half = not save_full_precision
|
||||
|
||||
logging.info(f" * Saving SD model to {sd_ckpt_full}")
|
||||
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
|
||||
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_strength_target != None)
|
||||
ema_model_loaded_from_file = False
|
||||
|
||||
if yaml_name and yaml_name != "v1-inference.yaml":
|
||||
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
|
||||
logging.info(f" * Saving yaml to {yaml_save_path}")
|
||||
shutil.copyfile(yaml_name, yaml_save_path)
|
||||
|
||||
if save_optimizer_flag:
|
||||
logging.info(f" Saving optimizer state to {save_path}")
|
||||
ed_optimizer.save(save_path)
|
||||
if use_ema_dacay_training:
|
||||
ema_device = torch.device(args.ema_device)
|
||||
|
||||
optimizer_state_path = None
|
||||
|
||||
try:
|
||||
# check for a local file
|
||||
hf_cache_path = get_hf_ckpt_cache_path(args.resume_ckpt)
|
||||
|
@ -443,10 +600,6 @@ def main(args):
|
|||
text_encoder = CLIPTextModel.from_pretrained(model_root_folder, subfolder="text_encoder")
|
||||
vae = AutoencoderKL.from_pretrained(model_root_folder, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(model_root_folder, subfolder="unet")
|
||||
|
||||
optimizer_state_path = os.path.join(args.resume_ckpt, "optimizer.pt")
|
||||
if not os.path.exists(optimizer_state_path):
|
||||
optimizer_state_path = None
|
||||
else:
|
||||
# try to download from HF using resume_ckpt as a repo id
|
||||
downloaded = try_download_model_from_hf(repo_id=args.resume_ckpt)
|
||||
|
@ -457,16 +610,52 @@ def main(args):
|
|||
vae = pipe.vae
|
||||
unet = pipe.unet
|
||||
del pipe
|
||||
|
||||
if args.zero_frequency_noise_ratio == -1.0:
|
||||
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
|
||||
|
||||
if use_ema_dacay_training and args.ema_resume_model:
|
||||
print(f"Loading EMA model: {args.ema_resume_model}")
|
||||
ema_model_loaded_from_file=True
|
||||
hf_cache_path = get_hf_ckpt_cache_path(args.ema_resume_model)
|
||||
|
||||
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_resume_model):
|
||||
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.resume_ckpt)
|
||||
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
|
||||
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
|
||||
|
||||
else:
|
||||
# try to download from HF using ema_resume_model as a repo id
|
||||
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_resume_model)
|
||||
if ema_downloaded is None:
|
||||
raise ValueError(
|
||||
f"No local file/folder for ema_resume_model {args.ema_resume_model}, and no matching huggingface.co repo could be downloaded")
|
||||
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
|
||||
text_encoder_ema = ema_pipe.text_encoder
|
||||
unet_ema = ema_pipe.unet
|
||||
del ema_pipe
|
||||
|
||||
# Make sure EMA model is on proper device, and memory released if moved
|
||||
unet_ema_current_device = next(unet_ema.parameters()).device
|
||||
if ema_device != unet_ema_current_device:
|
||||
unet_ema_on_wrong_device = unet_ema
|
||||
unet_ema = unet_ema.to(ema_device)
|
||||
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
|
||||
|
||||
# Make sure EMA model is on proper device, and memory released if moved
|
||||
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
|
||||
if ema_device != text_encoder_ema_current_device:
|
||||
text_encoder_ema_on_wrong_device = text_encoder_ema
|
||||
text_encoder_ema = text_encoder_ema.to(ema_device)
|
||||
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
|
||||
|
||||
|
||||
if args.enable_zero_terminal_snr:
|
||||
# Use zero terminal SNR
|
||||
from utils.unet_utils import enforce_zero_terminal_snr
|
||||
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
|
||||
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
else:
|
||||
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
|
||||
|
@ -503,6 +692,32 @@ def main(args):
|
|||
else:
|
||||
text_encoder = text_encoder.to(device, dtype=torch.float32)
|
||||
|
||||
|
||||
|
||||
|
||||
if use_ema_dacay_training:
|
||||
if not ema_model_loaded_from_file:
|
||||
logging.info(f"EMA decay enabled, creating EMA model.")
|
||||
|
||||
with torch.no_grad():
|
||||
if args.ema_device == device:
|
||||
unet_ema = deepcopy(unet)
|
||||
text_encoder_ema = deepcopy(text_encoder)
|
||||
else:
|
||||
unet_ema_first = deepcopy(unet)
|
||||
text_encoder_ema_first = deepcopy(text_encoder)
|
||||
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
|
||||
del unet_ema_first
|
||||
del text_encoder_ema_first
|
||||
else:
|
||||
# Make sure correct types are used for models
|
||||
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
|
||||
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
|
||||
else:
|
||||
unet_ema = None
|
||||
text_encoder_ema = None
|
||||
|
||||
try:
|
||||
#unet = torch.compile(unet)
|
||||
#text_encoder = torch.compile(text_encoder)
|
||||
|
@ -566,6 +781,7 @@ def main(args):
|
|||
tokenizer=tokenizer,
|
||||
seed = seed,
|
||||
shuffle_tags=args.shuffle_tags,
|
||||
keep_tags=args.keep_tags,
|
||||
rated_dataset=args.rated_dataset,
|
||||
rated_dataset_dropout_target=(1.0 - (args.rated_dataset_target_dropout_percent / 100.0))
|
||||
)
|
||||
|
@ -574,6 +790,20 @@ def main(args):
|
|||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
|
||||
if use_ema_dacay_training:
|
||||
args.ema_update_interval = args.ema_update_interval * args.grad_accum
|
||||
if args.ema_strength_target != None:
|
||||
total_number_of_steps: float = epoch_len * args.max_epochs
|
||||
total_number_of_ema_update: float = total_number_of_steps / args.ema_update_interval
|
||||
args.ema_decay_rate = args.ema_strength_target ** (1 / total_number_of_ema_update)
|
||||
|
||||
logging.info(f"ema_strength_target is {args.ema_strength_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
|
||||
|
||||
logging.info(
|
||||
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_update_interval: {args.ema_update_interval}, ema_device: {args.ema_device}.")
|
||||
|
||||
|
||||
ed_optimizer = EveryDreamOptimizer(args,
|
||||
optimizer_config,
|
||||
text_encoder,
|
||||
|
@ -588,7 +818,8 @@ def main(args):
|
|||
batch_size=max(1,args.batch_size//2),
|
||||
default_sample_steps=args.sample_steps,
|
||||
use_xformers=is_xformers_available() and not args.disable_xformers,
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2)
|
||||
use_penultimate_clip_layer=(args.clip_skip >= 2),
|
||||
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
|
||||
)
|
||||
|
||||
"""
|
||||
|
@ -620,7 +851,9 @@ def main(args):
|
|||
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
|
||||
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
|
||||
time.sleep(2) # give opportunity to ctrl-C again to cancel save
|
||||
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
save_model(interrupted_checkpoint_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
|
||||
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
exit(_SIGTERM_EXIT_CODE)
|
||||
else:
|
||||
# non-main threads (i.e. dataloader workers) should exit cleanly
|
||||
|
@ -668,7 +901,7 @@ def main(args):
|
|||
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
|
||||
|
||||
# actual prediction function - shared between train and validate
|
||||
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
|
||||
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
|
||||
with torch.no_grad():
|
||||
with autocast(enabled=args.amp):
|
||||
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
|
||||
|
@ -676,12 +909,13 @@ def main(args):
|
|||
del pixel_values
|
||||
latents = latents[0].sample() * 0.18215
|
||||
|
||||
if zero_frequency_noise_ratio > 0.0:
|
||||
if zero_frequency_noise_ratio != None:
|
||||
if zero_frequency_noise_ratio < 0:
|
||||
zero_frequency_noise_ratio = 0
|
||||
|
||||
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
|
||||
noise = torch.randn_like(latents) + zero_frequency_noise
|
||||
else:
|
||||
noise = torch.randn_like(latents)
|
||||
|
||||
bsz = latents.shape[0]
|
||||
|
||||
|
@ -712,9 +946,35 @@ def main(args):
|
|||
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
return model_pred, target
|
||||
if return_loss:
|
||||
if args.min_snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
else:
|
||||
snr = compute_snr(timesteps, noise_scheduler)
|
||||
|
||||
mse_loss_weights = (
|
||||
torch.stack(
|
||||
[snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1
|
||||
).min(dim=1)[0]
|
||||
/ snr
|
||||
)
|
||||
mse_loss_weights[snr == 0] = 1.0
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
|
||||
return model_pred, target, loss
|
||||
|
||||
else:
|
||||
return model_pred, target
|
||||
|
||||
def generate_samples(global_step: int, batch):
|
||||
nonlocal unet
|
||||
nonlocal text_encoder
|
||||
nonlocal unet_ema
|
||||
nonlocal text_encoder_ema
|
||||
|
||||
with isolate_rng():
|
||||
prev_sample_steps = sample_generator.sample_steps
|
||||
sample_generator.reload_config()
|
||||
|
@ -723,19 +983,78 @@ def main(args):
|
|||
print(f" * SampleGenerator config changed, now generating images samples every " +
|
||||
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
|
||||
sample_generator.update_random_captions(batch["captions"])
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=reference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step)
|
||||
|
||||
del inference_pipe
|
||||
gc.collect()
|
||||
models_info = []
|
||||
|
||||
if (args.ema_decay_rate is None) or args.ema_sample_nonema_model:
|
||||
models_info.append({"is_ema": False, "swap_required": False})
|
||||
|
||||
if (args.ema_decay_rate is not None) and args.ema_sample_ema_model:
|
||||
models_info.append({"is_ema": True, "swap_required": ema_device != device})
|
||||
|
||||
for model_info in models_info:
|
||||
|
||||
extra_info: str = ""
|
||||
|
||||
if model_info["is_ema"]:
|
||||
current_unet, current_text_encoder = unet_ema, text_encoder_ema
|
||||
extra_info = "_ema"
|
||||
else:
|
||||
current_unet, current_text_encoder = unet, text_encoder
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet_unloaded = unet.to(ema_device)
|
||||
del unet
|
||||
text_encoder_unloaded = text_encoder.to(ema_device)
|
||||
del text_encoder
|
||||
|
||||
current_unet = unet_ema.to(device)
|
||||
del unet_ema
|
||||
current_text_encoder = text_encoder_ema.to(device)
|
||||
del text_encoder_ema
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
|
||||
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
|
||||
text_encoder=current_text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
diffusers_scheduler_config=inference_scheduler.config
|
||||
).to(device)
|
||||
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
|
||||
|
||||
# Cleanup
|
||||
del inference_pipe
|
||||
|
||||
if model_info["swap_required"]:
|
||||
with torch.no_grad():
|
||||
unet = unet_unloaded.to(device)
|
||||
del unet_unloaded
|
||||
text_encoder = text_encoder_unloaded.to(device)
|
||||
del text_encoder_unloaded
|
||||
|
||||
unet_ema = current_unet.to(ema_device)
|
||||
del current_unet
|
||||
text_encoder_ema = current_text_encoder.to(ema_device)
|
||||
del current_text_encoder
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def make_save_path(epoch, global_step, prepend=""):
|
||||
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
|
||||
basename = f"{prepend}{args.project_name}"
|
||||
if epoch is not None:
|
||||
basename += f"-ep{epoch:02}"
|
||||
if global_step is not None:
|
||||
basename += f"-gs{global_step:05}"
|
||||
return os.path.join(log_folder, "ckpts", basename)
|
||||
|
||||
|
||||
# Pre-train validation to establish a starting point on the loss graph
|
||||
if validator:
|
||||
|
@ -753,25 +1072,47 @@ def main(args):
|
|||
else:
|
||||
logging.info("No plugins specified")
|
||||
plugins = []
|
||||
|
||||
|
||||
from plugins.plugins import PluginRunner
|
||||
plugin_runner = PluginRunner(plugins=plugins)
|
||||
|
||||
def make_current_ed_state() -> EveryDreamTrainingState:
|
||||
return EveryDreamTrainingState(optimizer=ed_optimizer,
|
||||
train_batch=train_batch,
|
||||
unet=unet,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=noise_scheduler,
|
||||
vae=vae,
|
||||
unet_ema=unet_ema,
|
||||
text_encoder_ema=text_encoder_ema)
|
||||
|
||||
epoch = None
|
||||
try:
|
||||
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
|
||||
plugin_runner.run_on_training_start(log_folder=log_folder, project_name=args.project_name)
|
||||
|
||||
for epoch in range(args.max_epochs):
|
||||
plugin_runner.run_on_epoch_start(epoch=epoch,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
data_root=args.data_root)
|
||||
|
||||
if args.load_settings_every_epoch:
|
||||
load_train_json_from_file(args)
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
|
||||
plugin_runner.run_on_epoch_start(
|
||||
epoch=epoch,
|
||||
global_step=global_step,
|
||||
epoch_length=epoch_len,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
data_root=args.data_root
|
||||
)
|
||||
|
||||
|
||||
loss_epoch = []
|
||||
epoch_start_time = time.time()
|
||||
images_per_sec_log_step = []
|
||||
|
||||
epoch_len = math.ceil(len(train_batch) / args.batch_size)
|
||||
steps_pbar = tqdm(range(epoch_len), position=1, leave=False, dynamic_ncols=True)
|
||||
steps_pbar.set_description(f"{Fore.LIGHTCYAN_EX}Steps{Style.RESET_ALL}")
|
||||
|
||||
|
@ -781,16 +1122,18 @@ def main(args):
|
|||
)
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
|
||||
step_start_time = time.time()
|
||||
|
||||
plugin_runner.run_on_step_start(epoch=epoch,
|
||||
local_step=step,
|
||||
global_step=global_step,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
batch=batch)
|
||||
batch=batch,
|
||||
ed_state=make_current_ed_state())
|
||||
|
||||
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
|
||||
|
||||
del target, model_pred
|
||||
|
||||
|
@ -800,6 +1143,21 @@ def main(args):
|
|||
|
||||
ed_optimizer.step(loss, step, global_step)
|
||||
|
||||
if args.ema_decay_rate != None:
|
||||
if ((global_step + 1) % args.ema_update_interval) == 0:
|
||||
# debug_start_time = time.time() # Measure time
|
||||
|
||||
if args.disable_unet_training != True:
|
||||
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
|
||||
|
||||
if args.disable_textenc_training != True:
|
||||
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
|
||||
|
||||
# debug_end_time = time.time() # Measure time
|
||||
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
|
||||
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
|
||||
|
||||
|
||||
loss_step = loss.detach().item()
|
||||
|
||||
steps_pbar.set_postfix({"loss/step": loss_step}, {"gs": global_step})
|
||||
|
@ -816,7 +1174,7 @@ def main(args):
|
|||
lr_unet = ed_optimizer.get_unet_lr()
|
||||
lr_textenc = ed_optimizer.get_textenc_lr()
|
||||
loss_log_step = []
|
||||
|
||||
|
||||
log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
|
||||
log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
|
||||
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_step, global_step=global_step)
|
||||
|
@ -840,23 +1198,29 @@ def main(args):
|
|||
|
||||
min_since_last_ckpt = (time.time() - last_epoch_saved_time) / 60
|
||||
|
||||
needs_save = False
|
||||
if args.ckpt_every_n_minutes is not None and (min_since_last_ckpt > args.ckpt_every_n_minutes):
|
||||
last_epoch_saved_time = time.time()
|
||||
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
|
||||
needs_save = True
|
||||
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs and epoch >= args.save_ckpts_from_n_epochs:
|
||||
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
|
||||
needs_save = True
|
||||
if needs_save:
|
||||
save_path = make_save_path(epoch, global_step)
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||
save_ckpt_dir=args.save_ckpt_dir, yaml_name=None,
|
||||
save_full_precision=args.save_full_precision,
|
||||
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
plugin_runner.run_on_step_end(epoch=epoch,
|
||||
global_step=global_step,
|
||||
local_step=step,
|
||||
project_name=args.project_name,
|
||||
log_folder=log_folder,
|
||||
data_root=args.data_root,
|
||||
batch=batch)
|
||||
batch=batch,
|
||||
ed_state=make_current_ed_state())
|
||||
|
||||
del batch
|
||||
global_step += 1
|
||||
|
@ -873,8 +1237,9 @@ def main(args):
|
|||
train_batch.shuffle(epoch_n=epoch, max_epochs = args.max_epochs)
|
||||
write_batch_schedule(args, log_folder, train_batch, epoch + 1)
|
||||
|
||||
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
|
||||
if len(loss_epoch) > 0:
|
||||
loss_epoch = sum(loss_epoch) / len(loss_epoch)
|
||||
log_writer.add_scalar(tag="loss/epoch", scalar_value=loss_epoch, global_step=global_step)
|
||||
|
||||
plugin_runner.run_on_epoch_end(epoch=epoch,
|
||||
global_step=global_step,
|
||||
|
@ -882,13 +1247,18 @@ def main(args):
|
|||
log_folder=log_folder,
|
||||
data_root=args.data_root)
|
||||
|
||||
gc.collect()
|
||||
gc.collect()
|
||||
# end of epoch
|
||||
|
||||
# end of training
|
||||
epoch = args.max_epochs
|
||||
|
||||
plugin_runner.run_on_training_end()
|
||||
|
||||
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
|
||||
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
|
||||
total_elapsed_time = time.time() - training_start_time
|
||||
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
|
||||
|
@ -898,7 +1268,9 @@ def main(args):
|
|||
except Exception as ex:
|
||||
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
|
||||
save_path = make_save_path(epoch, global_step, prepend="errored-")
|
||||
__save_model(save_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
save_model(save_path, global_step=global_step, ed_state=make_current_ed_state(),
|
||||
save_ckpt_dir=args.save_ckpt_dir, yaml_name=yaml, save_full_precision=args.save_full_precision,
|
||||
save_optimizer_flag=args.save_optimizer, save_ckpt=not args.no_save_ckpt)
|
||||
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
|
||||
raise ex
|
||||
|
||||
|
@ -914,14 +1286,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
|
||||
args, argv = argparser.parse_known_args()
|
||||
|
||||
if args.config is not None:
|
||||
print(f"Loading training config from {args.config}.")
|
||||
with open(args.config, 'rt') as f:
|
||||
args.__dict__.update(json.load(f))
|
||||
if len(argv) > 0:
|
||||
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
|
||||
else:
|
||||
print("No config file specified, using command line args")
|
||||
load_train_json_from_file(args, report_load=True)
|
||||
|
||||
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
|
||||
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
|
||||
|
@ -936,7 +1301,7 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--disable_unet_training", action="store_true", default=False, help="disables training of unet (def: False) NOT RECOMMENDED")
|
||||
argparser.add_argument("--disable_xformers", action="store_true", default=False, help="disable xformers, may reduce performance (def: False)")
|
||||
argparser.add_argument("--flip_p", type=float, default=0.0, help="probability of flipping image horizontally (def: 0.0) use 0.0 to 1.0, ex 0.5, not good for specific faces!")
|
||||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1)")
|
||||
argparser.add_argument("--gpuid", type=int, default=0, help="id of gpu to use for training, (def: 0) (ex: 1 to use GPU_ID 1), use nvidia-smi to find your GPU ids")
|
||||
argparser.add_argument("--gradient_checkpointing", action="store_true", default=False, help="enable gradient checkpointing to reduce VRAM use, may reduce performance (def: False)")
|
||||
argparser.add_argument("--grad_accum", type=int, default=1, help="Gradient accumulation factor (def: 1), (ex, 2)")
|
||||
argparser.add_argument("--logdir", type=str, default="logs", help="folder to save logs to (def: logs)")
|
||||
|
@ -964,15 +1329,27 @@ if __name__ == "__main__":
|
|||
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
|
||||
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
|
||||
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
|
||||
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
|
||||
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
|
||||
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")
|
||||
argparser.add_argument("--validation_config", default=None, help="Path to a JSON configuration file for the validator. Default is no validation.")
|
||||
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
|
||||
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
|
||||
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
|
||||
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
|
||||
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
|
||||
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
|
||||
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameter is the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
|
||||
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_rate) from training, and the ema_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
|
||||
argparser.add_argument("--ema_strength_target", type=float, default=None, help="EMA decay target value in range (0,1). emarate will be calculated from equation: 'ema_decay_rate=ema_strength_target^(total_steps/ema_update_interval)'. Using this parameter will enable the ema feature and overide ema_decay_rate.")
|
||||
argparser.add_argument("--ema_update_interval", type=int, default=500, help="How many steps between optimizer steps that EMA decay updates. EMA model will be update on every step modulo grad_accum times ema_update_interval.")
|
||||
argparser.add_argument("--ema_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' the system RAM will be used.")
|
||||
argparser.add_argument("--ema_sample_nonema_model", action="store_true", default=False, help="Will show samples from non-EMA trained model, just like regular training. Can be used with: --ema_sample_ema_model")
|
||||
argparser.add_argument("--ema_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. May be slower when using ema cpu offloading. Can be used with: --ema_sample_nonema_model")
|
||||
argparser.add_argument("--ema_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
|
||||
|
||||
|
||||
# load CLI args to overwrite existing config args
|
||||
args = argparser.parse_args(args=argv, namespace=args)
|
||||
|
||||
|
||||
main(args)
|
||||
|
|
|
@ -39,5 +39,15 @@
|
|||
"write_schedule": false,
|
||||
"rated_dataset": false,
|
||||
"rated_dataset_target_dropout_percent": 50,
|
||||
"zero_frequency_noise_ratio": 0.02
|
||||
"zero_frequency_noise_ratio": 0.02,
|
||||
"enable_zero_terminal_snr": false,
|
||||
"load_settings_every_epoch": false,
|
||||
"min_snr_gamma": null,
|
||||
"ema_decay_rate": null,
|
||||
"ema_strength_target": null,
|
||||
"ema_update_interval": null,
|
||||
"ema_device": null,
|
||||
"ema_sample_nonema_model": false,
|
||||
"ema_sample_ema_model": false,
|
||||
"ema_resume_model" : null
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
|
||||
import os.path as osp
|
||||
import re
|
||||
from safetensors import safe_open
|
||||
|
||||
import torch
|
||||
|
||||
|
@ -288,17 +289,38 @@ def convert(model_path: str, checkpoint_path: str, half: bool):
|
|||
text_enc_path = osp.join(model_path, "text_encoder", "pytorch_model.bin")
|
||||
|
||||
# Convert the UNet model
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
if osp.exists(unet_path):
|
||||
unet_state_dict = torch.load(unet_path, map_location="cpu")
|
||||
else:
|
||||
unet_state_dict = {}
|
||||
unet_path = osp.join(model_path, "unet", "diffusion_pytorch_model.safetensors")
|
||||
with safe_open(unet_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
unet_state_dict[key] = f.get_tensor(key)
|
||||
unet_state_dict = convert_unet_state_dict(unet_state_dict)
|
||||
unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}
|
||||
|
||||
# Convert the VAE model
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
if osp.exists(vae_path):
|
||||
vae_state_dict = torch.load(vae_path, map_location="cpu")
|
||||
else:
|
||||
vae_state_dict = {}
|
||||
vae_path = osp.join(model_path, "vae", "diffusion_pytorch_model.safetensors")
|
||||
with safe_open(vae_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
vae_state_dict[key] = f.get_tensor(key)
|
||||
vae_state_dict = convert_vae_state_dict(vae_state_dict)
|
||||
vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}
|
||||
|
||||
# Convert the text encoder model
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
if osp.exists(text_enc_path):
|
||||
text_enc_dict = torch.load(text_enc_path, map_location="cpu")
|
||||
else:
|
||||
text_enc_dict = {}
|
||||
text_enc_path = osp.join(model_path, "text_encoder", "model.safetensors")
|
||||
with safe_open(text_enc_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
text_enc_dict[key] = f.get_tensor(key)
|
||||
|
||||
# Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper
|
||||
is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict
|
||||
|
|
|
@ -90,7 +90,8 @@ class SampleGenerator:
|
|||
default_seed: int,
|
||||
default_sample_steps: int,
|
||||
use_xformers: bool,
|
||||
use_penultimate_clip_layer: bool):
|
||||
use_penultimate_clip_layer: bool,
|
||||
guidance_rescale: float = 0):
|
||||
self.log_folder = log_folder
|
||||
self.log_writer = log_writer
|
||||
self.batch_size = batch_size
|
||||
|
@ -99,6 +100,7 @@ class SampleGenerator:
|
|||
self.show_progress_bars = False
|
||||
self.generate_pretrain_samples = False
|
||||
self.use_penultimate_clip_layer = use_penultimate_clip_layer
|
||||
self.guidance_rescale = guidance_rescale
|
||||
|
||||
self.default_resolution = default_resolution
|
||||
self.default_seed = default_seed
|
||||
|
@ -182,7 +184,7 @@ class SampleGenerator:
|
|||
self.sample_requests = self._make_random_caption_sample_requests()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
|
||||
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
|
||||
"""
|
||||
generates samples at different cfg scales and saves them to disk
|
||||
"""
|
||||
|
@ -231,6 +233,7 @@ class SampleGenerator:
|
|||
generator=generators,
|
||||
width=size[0],
|
||||
height=size[1],
|
||||
guidance_rescale=self.guidance_rescale
|
||||
).images
|
||||
|
||||
for image in images:
|
||||
|
@ -269,15 +272,15 @@ class SampleGenerator:
|
|||
prompt = prompts[prompt_idx]
|
||||
clean_prompt = clean_filename(prompt)
|
||||
|
||||
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
|
||||
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
|
||||
f.write(str(batch[prompt_idx]))
|
||||
|
||||
tfimage = transforms.ToTensor()(result)
|
||||
if batch[prompt_idx].wants_random_caption:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
|
||||
else:
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
|
||||
sample_index += 1
|
||||
|
||||
del result
|
||||
|
|
|
@ -7,7 +7,7 @@ pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url "http
|
|||
pip install -U transformers==4.29.2
|
||||
pip install -U diffusers[torch]==0.18.0
|
||||
pip install pynvml==11.4.1
|
||||
pip install -U https://github.com/victorchall/everydream-whls/raw/main/bitsandbytes-0.38.1-py2.py3-none-any.whl
|
||||
pip install -U pip install -U https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.41.1-py3-none-win_amd64.whl
|
||||
pip install ftfy==6.1.1
|
||||
pip install aiohttp==3.8.4
|
||||
pip install tensorboard>=2.11.0
|
||||
|
@ -23,6 +23,7 @@ pip install compel~=1.1.3
|
|||
pip install dadaptation
|
||||
pip install safetensors
|
||||
pip install open-flamingo==2.0.0
|
||||
pip install prodigyopt
|
||||
python utils/get_yamls.py
|
||||
GOTO :eof
|
||||
|
||||
|
|
Loading…
Reference in New Issue