Add original files
|
@ -0,0 +1,4 @@
|
||||||
|
./venv
|
||||||
|
./danbooru-aesthetic
|
||||||
|
./logs
|
||||||
|
*.ckpt
|
|
@ -1,4 +1,24 @@
|
||||||
# ---> Python
|
# OS-generated
|
||||||
|
# ------------
|
||||||
|
.DS_Store*
|
||||||
|
[Tt]humbs.db
|
||||||
|
[Dd]esktop.ini
|
||||||
|
|
||||||
|
# Programming - general
|
||||||
|
*.log
|
||||||
|
example.png
|
||||||
|
scores.json
|
||||||
|
danbooru-aesthetic
|
||||||
|
logs
|
||||||
|
|
||||||
|
# =========================================================================== #
|
||||||
|
# Python-related
|
||||||
|
# =========================================================================== #
|
||||||
|
# src: https://github.com/github/gitignore/blob/master/Python.gitignore
|
||||||
|
|
||||||
|
# JetBrains PyCharm / Rider
|
||||||
|
.idea/
|
||||||
|
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
|
@ -20,6 +40,7 @@ lib64/
|
||||||
parts/
|
parts/
|
||||||
sdist/
|
sdist/
|
||||||
var/
|
var/
|
||||||
|
venv/
|
||||||
wheels/
|
wheels/
|
||||||
share/python-wheels/
|
share/python-wheels/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
|
@ -27,114 +48,11 @@ share/python-wheels/
|
||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
|
|
||||||
# PyInstaller
|
|
||||||
# Usually these files are written by a python script from a template
|
|
||||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
||||||
*.manifest
|
|
||||||
*.spec
|
|
||||||
|
|
||||||
# Installer logs
|
# =========================================================================== #
|
||||||
pip-log.txt
|
# Repo-specific
|
||||||
pip-delete-this-directory.txt
|
# =========================================================================== #
|
||||||
|
/src/
|
||||||
# Unit test / coverage reports
|
|
||||||
htmlcov/
|
|
||||||
.tox/
|
|
||||||
.nox/
|
|
||||||
.coverage
|
|
||||||
.coverage.*
|
|
||||||
.cache
|
|
||||||
nosetests.xml
|
|
||||||
coverage.xml
|
|
||||||
*.cover
|
|
||||||
*.py,cover
|
|
||||||
.hypothesis/
|
|
||||||
.pytest_cache/
|
|
||||||
cover/
|
|
||||||
|
|
||||||
# Translations
|
|
||||||
*.mo
|
|
||||||
*.pot
|
|
||||||
|
|
||||||
# Django stuff:
|
|
||||||
*.log
|
|
||||||
local_settings.py
|
|
||||||
db.sqlite3
|
|
||||||
db.sqlite3-journal
|
|
||||||
|
|
||||||
# Flask stuff:
|
|
||||||
instance/
|
|
||||||
.webassets-cache
|
|
||||||
|
|
||||||
# Scrapy stuff:
|
|
||||||
.scrapy
|
|
||||||
|
|
||||||
# Sphinx documentation
|
|
||||||
docs/_build/
|
|
||||||
|
|
||||||
# PyBuilder
|
|
||||||
.pybuilder/
|
|
||||||
target/
|
|
||||||
|
|
||||||
# Jupyter Notebook
|
|
||||||
.ipynb_checkpoints
|
|
||||||
|
|
||||||
# IPython
|
|
||||||
profile_default/
|
|
||||||
ipython_config.py
|
|
||||||
|
|
||||||
# pyenv
|
|
||||||
# For a library or package, you might want to ignore these files since the code is
|
|
||||||
# intended to run in multiple environments; otherwise, check them in:
|
|
||||||
# .python-version
|
|
||||||
|
|
||||||
# pipenv
|
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
||||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
||||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
||||||
# install all needed dependencies.
|
|
||||||
#Pipfile.lock
|
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
|
||||||
__pypackages__/
|
|
||||||
|
|
||||||
# Celery stuff
|
|
||||||
celerybeat-schedule
|
|
||||||
celerybeat.pid
|
|
||||||
|
|
||||||
# SageMath parsed files
|
|
||||||
*.sage.py
|
|
||||||
|
|
||||||
# Environments
|
|
||||||
.env
|
|
||||||
.venv
|
|
||||||
env/
|
|
||||||
venv/
|
|
||||||
ENV/
|
|
||||||
env.bak/
|
|
||||||
venv.bak/
|
|
||||||
|
|
||||||
# Spyder project settings
|
|
||||||
.spyderproject
|
|
||||||
.spyproject
|
|
||||||
|
|
||||||
# Rope project settings
|
|
||||||
.ropeproject
|
|
||||||
|
|
||||||
# mkdocs documentation
|
|
||||||
/site
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
|
|
||||||
# Pyre type checker
|
|
||||||
.pyre/
|
|
||||||
|
|
||||||
# pytype static type analyzer
|
|
||||||
.pytype/
|
|
||||||
|
|
||||||
# Cython debug symbols
|
|
||||||
cython_debug/
|
|
||||||
|
|
||||||
|
#Obsidian
|
||||||
|
.obsidian/
|
||||||
|
|
|
@ -0,0 +1,10 @@
|
||||||
|
FROM pytorch/pytorch:latest
|
||||||
|
|
||||||
|
RUN apt update && \
|
||||||
|
apt install -y git curl unzip vim && \
|
||||||
|
pip install git+https://github.com/derfred/lightning.git@waifu-1.6.0#egg=pytorch-lightning
|
||||||
|
RUN mkdir /waifu
|
||||||
|
COPY . /waifu/
|
||||||
|
WORKDIR /waifu
|
||||||
|
RUN grep -v pytorch-lightning requirements.txt > requirements-waifu.txt && \
|
||||||
|
pip install -r requirements-waifu.txt
|
19
LICENSE
|
@ -1,9 +1,14 @@
|
||||||
MIT License
|
All rights reserved by the authors.
|
||||||
|
You must not distribute the weights provided to you directly or indirectly without explicit consent of the authors.
|
||||||
|
You must not distribute harmful, offensive, dehumanizing content or otherwise harmful representations of people or their environments, cultures, religions, etc. produced with the model weights
|
||||||
|
or other generated content described in the "Misuse and Malicious Use" section in the model card.
|
||||||
|
The model weights are provided for research purposes only.
|
||||||
|
|
||||||
Copyright (c) <year> <copyright holders>
|
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
56
README.md
|
@ -1,3 +1,55 @@
|
||||||
# waifu-diffusion-original
|
|
||||||
|
|
||||||
The Waifu Diffusion git before they deleted everything.
|
|
||||||
|
# Waifu Diffusion
|
||||||
|
|
||||||
|
Waifu Diffusion is the name for this project of finetuning Stable Diffusion on images and captions downloaded through Danbooru
|
||||||
|
|
||||||
|
(**Note:** This project has **no affiliation with Danbooru.**)
|
||||||
|
|
||||||
|
<img src=https://cdn.discordapp.com/attachments/872361510133981234/1016022078635388979/unknown.png?3867929 width=40% height=40%>
|
||||||
|
|
||||||
|
<sub>Prompt: touhou 1girl komeiji_koishi portrait</sub>
|
||||||
|
|
||||||
|
## Documentation
|
||||||
|
|
||||||
|
[Index](./docs/en/README.md)
|
||||||
|
|
||||||
|
[Weights](./docs/en/weights/README.md)
|
||||||
|
|
||||||
|
[Training Guide](./docs/en/training/README.md)
|
||||||
|
|
||||||
|
All thanks goes to CompVis and Stability AI for releasing this codebase!
|
||||||
|
|
||||||
|
Model Link: https://huggingface.co/hakurei/waifu-diffusion
|
||||||
|
|
||||||
|
### Any questions? Come hop on by to our Discord server!
|
||||||
|
|
||||||
|
[![Discord Server](https://discordapp.com/api/guilds/930499730843250783/widget.png?style=banner2)](https://discord.gg/Sx6Spmsgx7)
|
||||||
|
|
||||||
|
# Stable Diffusion
|
||||||
|
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*
|
||||||
|
|
||||||
|
## Comments
|
||||||
|
|
||||||
|
- Our codebase for the diffusion models builds heavily on [OpenAI's ADM codebase](https://github.com/openai/guided-diffusion)
|
||||||
|
and [https://github.com/lucidrains/denoising-diffusion-pytorch](https://github.com/lucidrains/denoising-diffusion-pytorch).
|
||||||
|
Thanks for open-sourcing!
|
||||||
|
|
||||||
|
- The implementation of the transformer encoder is from [x-transformers](https://github.com/lucidrains/x-transformers) by [lucidrains](https://github.com/lucidrains?tab=repositories).
|
||||||
|
|
||||||
|
|
||||||
|
## BibTeX
|
||||||
|
|
||||||
|
```
|
||||||
|
@misc{rombach2021highresolution,
|
||||||
|
title={High-Resolution Image Synthesis with Latent Diffusion Models},
|
||||||
|
author={Robin Rombach and Andreas Blattmann and Dominik Lorenz and Patrick Esser and Björn Ommer},
|
||||||
|
year={2021},
|
||||||
|
eprint={2112.10752},
|
||||||
|
archivePrefix={arXiv},
|
||||||
|
primaryClass={cs.CV}
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,140 @@
|
||||||
|
# Stable Diffusion v1 Model Card
|
||||||
|
This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
|
||||||
|
|
||||||
|
## Model Details
|
||||||
|
- **Developed by:** Robin Rombach, Patrick Esser
|
||||||
|
- **Model type:** Diffusion-based text-to-image generation model
|
||||||
|
- **Language(s):** English
|
||||||
|
- **License:** [Proprietary](LICENSE)
|
||||||
|
- **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
|
||||||
|
- **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
|
||||||
|
- **Cite as:**
|
||||||
|
|
||||||
|
@InProceedings{Rombach_2022_CVPR,
|
||||||
|
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
||||||
|
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
||||||
|
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
|
month = {June},
|
||||||
|
year = {2022},
|
||||||
|
pages = {10684-10695}
|
||||||
|
}
|
||||||
|
|
||||||
|
# Uses
|
||||||
|
|
||||||
|
## Direct Use
|
||||||
|
The model is intended for research purposes only. Possible research areas and
|
||||||
|
tasks include
|
||||||
|
|
||||||
|
- Safe deployment of models which have the potential to generate harmful content.
|
||||||
|
- Probing and understanding the limitations and biases of generative models.
|
||||||
|
- Generation of artworks and use in design and other artistic processes.
|
||||||
|
- Applications in educational or creative tools.
|
||||||
|
- Research on generative models.
|
||||||
|
|
||||||
|
Excluded uses are described below.
|
||||||
|
|
||||||
|
### Misuse, Malicious Use, and Out-of-Scope Use
|
||||||
|
_Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
|
||||||
|
|
||||||
|
|
||||||
|
The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
|
||||||
|
#### Out-of-Scope Use
|
||||||
|
The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
|
||||||
|
#### Misuse and Malicious Use
|
||||||
|
Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
|
||||||
|
|
||||||
|
- Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
|
||||||
|
- Intentionally promoting or propagating discriminatory content or harmful stereotypes.
|
||||||
|
- Impersonating individuals without their consent.
|
||||||
|
- Sexual content without consent of the people who might see it.
|
||||||
|
- Mis- and disinformation
|
||||||
|
- Representations of egregious violence and gore
|
||||||
|
- Sharing of copyrighted or licensed material in violation of its terms of use.
|
||||||
|
- Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
|
||||||
|
|
||||||
|
## Limitations and Bias
|
||||||
|
|
||||||
|
### Limitations
|
||||||
|
|
||||||
|
- The model does not achieve perfect photorealism
|
||||||
|
- The model cannot render legible text
|
||||||
|
- The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
|
||||||
|
- Faces and people in general may not be generated properly.
|
||||||
|
- The model was trained mainly with English captions and will not work as well in other languages.
|
||||||
|
- The autoencoding part of the model is lossy
|
||||||
|
- The model was trained on a large-scale dataset
|
||||||
|
[LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
|
||||||
|
and is not fit for product use without additional safety mechanisms and
|
||||||
|
considerations.
|
||||||
|
|
||||||
|
### Bias
|
||||||
|
While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
|
||||||
|
Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
|
||||||
|
which consists of images that are primarily limited to English descriptions.
|
||||||
|
Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
|
||||||
|
This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
|
||||||
|
ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
|
||||||
|
|
||||||
|
|
||||||
|
## Training
|
||||||
|
|
||||||
|
**Training Data**
|
||||||
|
The model developers used the following dataset for training the model:
|
||||||
|
|
||||||
|
- LAION-2B (en) and subsets thereof (see next section)
|
||||||
|
|
||||||
|
**Training Procedure**
|
||||||
|
Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
|
||||||
|
|
||||||
|
- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
|
||||||
|
- Text prompts are encoded through a ViT-L/14 text-encoder.
|
||||||
|
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
|
||||||
|
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
|
||||||
|
|
||||||
|
We currently provide three checkpoints, `sd-v1-1.ckpt`, `sd-v1-2.ckpt` and `sd-v1-3.ckpt`,
|
||||||
|
which were trained as follows,
|
||||||
|
|
||||||
|
- `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
|
||||||
|
194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
|
||||||
|
- `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
|
||||||
|
515k steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
|
||||||
|
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
|
||||||
|
- `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-improved-aesthetics" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
|
||||||
|
|
||||||
|
|
||||||
|
- **Hardware:** 32 x 8 x A100 GPUs
|
||||||
|
- **Optimizer:** AdamW
|
||||||
|
- **Gradient Accumulations**: 2
|
||||||
|
- **Batch:** 32 x 8 x 2 x 4 = 2048
|
||||||
|
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
|
||||||
|
|
||||||
|
## Evaluation Results
|
||||||
|
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
|
||||||
|
5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
|
||||||
|
steps show the relative improvements of the checkpoints:
|
||||||
|
|
||||||
|
![pareto](assets/v1-variants-scores.jpg)
|
||||||
|
|
||||||
|
Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores.
|
||||||
|
## Environmental Impact
|
||||||
|
|
||||||
|
**Stable Diffusion v1** **Estimated Emissions**
|
||||||
|
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.
|
||||||
|
|
||||||
|
- **Hardware Type:** A100 PCIe 40GB
|
||||||
|
- **Hours used:** 150000
|
||||||
|
- **Cloud Provider:** AWS
|
||||||
|
- **Compute Region:** US-east
|
||||||
|
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
|
||||||
|
## Citation
|
||||||
|
@InProceedings{Rombach_2022_CVPR,
|
||||||
|
author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
|
||||||
|
title = {High-Resolution Image Synthesis With Latent Diffusion Models},
|
||||||
|
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
||||||
|
month = {June},
|
||||||
|
year = {2022},
|
||||||
|
pages = {10684-10695}
|
||||||
|
}
|
||||||
|
|
||||||
|
*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
@echo off
|
||||||
|
IF NOT EXIST CONDA umamba create -r conda -f environment.yaml -y
|
||||||
|
call conda\condabin\activate.bat ldm
|
||||||
|
cls
|
||||||
|
|
||||||
|
:PROMPT
|
||||||
|
python scripts/txt2img_gradio.py
|
|
@ -0,0 +1,142 @@
|
||||||
|
import webdataset as wds
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
|
||||||
|
from warnings import filterwarnings
|
||||||
|
|
||||||
|
|
||||||
|
os.environ["CUDA_VISIBLE_DEVICES"] = "1" # choose GPU if you are on a multi GPU server
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch.nn as nn
|
||||||
|
from torchvision import datasets, transforms
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
from os.path import join
|
||||||
|
from datasets import load_dataset
|
||||||
|
import pandas as pd
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
import json
|
||||||
|
|
||||||
|
import clip
|
||||||
|
|
||||||
|
|
||||||
|
from PIL import Image, ImageFile
|
||||||
|
|
||||||
|
|
||||||
|
##### This script will predict the aesthetic score for this image file:
|
||||||
|
|
||||||
|
img_path = "../250k_data-0/img/000baa665498e7a61130d7662f81e698.jpg"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# if you changed the MLP architecture during training, change it also here:
|
||||||
|
class MLP(pl.LightningModule):
|
||||||
|
def __init__(self, input_size, xcol='emb', ycol='avg_rating'):
|
||||||
|
super().__init__()
|
||||||
|
self.input_size = input_size
|
||||||
|
self.xcol = xcol
|
||||||
|
self.ycol = ycol
|
||||||
|
self.layers = nn.Sequential(
|
||||||
|
nn.Linear(self.input_size, 1024),
|
||||||
|
#nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(1024, 128),
|
||||||
|
#nn.ReLU(),
|
||||||
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(128, 64),
|
||||||
|
#nn.ReLU(),
|
||||||
|
nn.Dropout(0.1),
|
||||||
|
|
||||||
|
nn.Linear(64, 16),
|
||||||
|
#nn.ReLU(),
|
||||||
|
|
||||||
|
nn.Linear(16, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.layers(x)
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx):
|
||||||
|
x = batch[self.xcol]
|
||||||
|
y = batch[self.ycol].reshape(-1, 1)
|
||||||
|
x_hat = self.layers(x)
|
||||||
|
loss = F.mse_loss(x_hat, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
x = batch[self.xcol]
|
||||||
|
y = batch[self.ycol].reshape(-1, 1)
|
||||||
|
x_hat = self.layers(x)
|
||||||
|
loss = F.mse_loss(x_hat, y)
|
||||||
|
return loss
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
|
||||||
|
return optimizer
|
||||||
|
|
||||||
|
def normalized(a, axis=-1, order=2):
|
||||||
|
import numpy as np # pylint: disable=import-outside-toplevel
|
||||||
|
|
||||||
|
l2 = np.atleast_1d(np.linalg.norm(a, order, axis))
|
||||||
|
l2[l2 == 0] = 1
|
||||||
|
return a / np.expand_dims(l2, axis)
|
||||||
|
|
||||||
|
|
||||||
|
model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
|
||||||
|
|
||||||
|
s = torch.load("sac+logos+ava1-l14-linearMSE.pth") # load the model you trained previously or the model available in this repo
|
||||||
|
|
||||||
|
model.load_state_dict(s)
|
||||||
|
|
||||||
|
model.to("cuda")
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
model2, preprocess = clip.load("ViT-L/14", device=device) #RN50x64
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def aesthetic(img_path):
|
||||||
|
pil_image = Image.open(img_path)
|
||||||
|
image = preprocess(pil_image).unsqueeze(0).to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
image_features = model2.encode_image(image)
|
||||||
|
im_emb_arr = normalized(image_features.cpu().detach().numpy())
|
||||||
|
prediction = model(torch.from_numpy(im_emb_arr).to(device).type(torch.cuda.FloatTensor))
|
||||||
|
return prediction.item()
|
||||||
|
|
||||||
|
import json
|
||||||
|
import glob
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
imdir = '../250k_data-0/img/'
|
||||||
|
ext = ['png', 'jpg', 'jpeg', 'bmp']
|
||||||
|
images = []
|
||||||
|
[images.extend(glob.glob(imdir + '*.' + e)) for e in ext]
|
||||||
|
|
||||||
|
aesthetic_scores = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
for i in tqdm.tqdm(images):
|
||||||
|
try:
|
||||||
|
score = aesthetic(i)
|
||||||
|
except:
|
||||||
|
print(f'skipping {i}')
|
||||||
|
continue
|
||||||
|
if score < 5.0:
|
||||||
|
shutil.move(i, i.replace('img', 'nonaesthetic'))
|
||||||
|
elif score > 6.0:
|
||||||
|
shutil.move(i, i.replace('img', 'aesthetic'))
|
||||||
|
aesthetic_scores[i] = score
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
with open('scores.json', 'w') as f:
|
||||||
|
f.write(json.dumps(aesthetic_scores))
|
After Width: | Height: | Size: 651 KiB |
After Width: | Height: | Size: 596 KiB |
After Width: | Height: | Size: 609 KiB |
After Width: | Height: | Size: 548 KiB |
After Width: | Height: | Size: 705 KiB |
After Width: | Height: | Size: 757 KiB |
After Width: | Height: | Size: 612 KiB |
After Width: | Height: | Size: 312 KiB |
After Width: | Height: | Size: 72 KiB |
After Width: | Height: | Size: 319 KiB |
After Width: | Height: | Size: 788 KiB |
After Width: | Height: | Size: 958 KiB |
After Width: | Height: | Size: 9.4 MiB |
After Width: | Height: | Size: 610 KiB |
After Width: | Height: | Size: 643 KiB |
After Width: | Height: | Size: 641 KiB |
After Width: | Height: | Size: 174 KiB |
After Width: | Height: | Size: 1.1 MiB |
After Width: | Height: | Size: 1.3 MiB |
After Width: | Height: | Size: 945 KiB |
After Width: | Height: | Size: 972 KiB |
After Width: | Height: | Size: 2.5 MiB |
After Width: | Height: | Size: 2.5 MiB |
After Width: | Height: | Size: 2.3 MiB |
After Width: | Height: | Size: 662 KiB |
After Width: | Height: | Size: 302 KiB |
After Width: | Height: | Size: 2.2 MiB |
After Width: | Height: | Size: 70 KiB |
|
@ -0,0 +1,54 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 4.5e-6
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: "val/rec_loss"
|
||||||
|
embed_dim: 16
|
||||||
|
lossconfig:
|
||||||
|
target: ldm.modules.losses.LPIPSWithDiscriminator
|
||||||
|
params:
|
||||||
|
disc_start: 50001
|
||||||
|
kl_weight: 0.000001
|
||||||
|
disc_weight: 0.5
|
||||||
|
|
||||||
|
ddconfig:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [16]
|
||||||
|
dropout: 0.0
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 12
|
||||||
|
wrap: True
|
||||||
|
train:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
validation:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: True
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
accumulate_grad_batches: 2
|
|
@ -0,0 +1,53 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 4.5e-6
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: "val/rec_loss"
|
||||||
|
embed_dim: 4
|
||||||
|
lossconfig:
|
||||||
|
target: ldm.modules.losses.LPIPSWithDiscriminator
|
||||||
|
params:
|
||||||
|
disc_start: 50001
|
||||||
|
kl_weight: 0.000001
|
||||||
|
disc_weight: 0.5
|
||||||
|
|
||||||
|
ddconfig:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 12
|
||||||
|
wrap: True
|
||||||
|
train:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
validation:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: True
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
accumulate_grad_batches: 2
|
|
@ -0,0 +1,54 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 4.5e-6
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: "val/rec_loss"
|
||||||
|
embed_dim: 3
|
||||||
|
lossconfig:
|
||||||
|
target: ldm.modules.losses.LPIPSWithDiscriminator
|
||||||
|
params:
|
||||||
|
disc_start: 50001
|
||||||
|
kl_weight: 0.000001
|
||||||
|
disc_weight: 0.5
|
||||||
|
|
||||||
|
ddconfig:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 12
|
||||||
|
wrap: True
|
||||||
|
train:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
validation:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: True
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
accumulate_grad_batches: 2
|
|
@ -0,0 +1,53 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 4.5e-6
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: "val/rec_loss"
|
||||||
|
embed_dim: 64
|
||||||
|
lossconfig:
|
||||||
|
target: ldm.modules.losses.LPIPSWithDiscriminator
|
||||||
|
params:
|
||||||
|
disc_start: 50001
|
||||||
|
kl_weight: 0.000001
|
||||||
|
disc_weight: 0.5
|
||||||
|
|
||||||
|
ddconfig:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 64
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [16,8]
|
||||||
|
dropout: 0.0
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 12
|
||||||
|
wrap: True
|
||||||
|
train:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
validation:
|
||||||
|
target: ldm.data.imagenet.ImageNetSRValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
degradation: pil_nearest
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 1000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: True
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
accumulate_grad_batches: 2
|
|
@ -0,0 +1,86 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 2.0e-06
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
image_size: 64
|
||||||
|
channels: 3
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 64
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
model_channels: 224
|
||||||
|
attention_resolutions:
|
||||||
|
# note: this isn\t actually the resolution but
|
||||||
|
# the downsampling factor, i.e. this corresnponds to
|
||||||
|
# attention on spatial resolution 8,16,32, as the
|
||||||
|
# spatial reolution of the latents is 64 for f4
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
num_head_channels: 32
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
embed_dim: 3
|
||||||
|
n_embed: 8192
|
||||||
|
ckpt_path: models/first_stage_models/vq-f4/model.ckpt
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config: __is_unconditional__
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 48
|
||||||
|
num_workers: 5
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: taming.data.faceshq.CelebAHQTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
validation:
|
||||||
|
target: taming.data.faceshq.CelebAHQValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
|
@ -0,0 +1,98 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-06
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: class_label
|
||||||
|
image_size: 32
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 256
|
||||||
|
attention_resolutions:
|
||||||
|
#note: this isn\t actually the resolution but
|
||||||
|
# the downsampling factor, i.e. this corresnponds to
|
||||||
|
# attention on spatial resolution 8,16,32, as the
|
||||||
|
# spatial reolution of the latents is 32 for f8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_head_channels: 32
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 512
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
n_embed: 16384
|
||||||
|
ckpt_path: configs/first_stage_models/vq-f8/model.yaml
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions:
|
||||||
|
- 32
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.ClassEmbedder
|
||||||
|
params:
|
||||||
|
embed_dim: 512
|
||||||
|
key: class_label
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 64
|
||||||
|
num_workers: 12
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: ldm.data.imagenet.ImageNetTrain
|
||||||
|
params:
|
||||||
|
config:
|
||||||
|
size: 256
|
||||||
|
validation:
|
||||||
|
target: ldm.data.imagenet.ImageNetValidation
|
||||||
|
params:
|
||||||
|
config:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
|
@ -0,0 +1,68 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 0.0001
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: class_label
|
||||||
|
image_size: 64
|
||||||
|
channels: 3
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 64
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
model_channels: 192
|
||||||
|
attention_resolutions:
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 5
|
||||||
|
num_heads: 1
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 512
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
embed_dim: 3
|
||||||
|
n_embed: 8192
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.ClassEmbedder
|
||||||
|
params:
|
||||||
|
n_classes: 1001
|
||||||
|
embed_dim: 512
|
||||||
|
key: class_label
|
|
@ -0,0 +1,85 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 2.0e-06
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
image_size: 64
|
||||||
|
channels: 3
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 64
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
model_channels: 224
|
||||||
|
attention_resolutions:
|
||||||
|
# note: this isn\t actually the resolution but
|
||||||
|
# the downsampling factor, i.e. this corresnponds to
|
||||||
|
# attention on spatial resolution 8,16,32, as the
|
||||||
|
# spatial reolution of the latents is 64 for f4
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
num_head_channels: 32
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
embed_dim: 3
|
||||||
|
n_embed: 8192
|
||||||
|
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config: __is_unconditional__
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 42
|
||||||
|
num_workers: 5
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: taming.data.faceshq.FFHQTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
validation:
|
||||||
|
target: taming.data.faceshq.FFHQValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
|
@ -0,0 +1,85 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 2.0e-06
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0195
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
image_size: 64
|
||||||
|
channels: 3
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 64
|
||||||
|
in_channels: 3
|
||||||
|
out_channels: 3
|
||||||
|
model_channels: 224
|
||||||
|
attention_resolutions:
|
||||||
|
# note: this isn\t actually the resolution but
|
||||||
|
# the downsampling factor, i.e. this corresnponds to
|
||||||
|
# attention on spatial resolution 8,16,32, as the
|
||||||
|
# spatial reolution of the latents is 64 for f4
|
||||||
|
- 8
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
num_head_channels: 32
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.VQModelInterface
|
||||||
|
params:
|
||||||
|
ckpt_path: configs/first_stage_models/vq-f4/model.yaml
|
||||||
|
embed_dim: 3
|
||||||
|
n_embed: 8192
|
||||||
|
ddconfig:
|
||||||
|
double_z: false
|
||||||
|
z_channels: 3
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config: __is_unconditional__
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 48
|
||||||
|
num_workers: 5
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: ldm.data.lsun.LSUNBedroomsTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
validation:
|
||||||
|
target: ldm.data.lsun.LSUNBedroomsValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
|
@ -0,0 +1,91 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.0155
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
loss_type: l1
|
||||||
|
first_stage_key: "image"
|
||||||
|
cond_stage_key: "image"
|
||||||
|
image_size: 32
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: False
|
||||||
|
concat_mode: False
|
||||||
|
scale_by_std: True
|
||||||
|
monitor: 'val/loss_simple_ema'
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [10000]
|
||||||
|
cycle_lengths: [10000000000000]
|
||||||
|
f_start: [1.e-6]
|
||||||
|
f_max: [1.]
|
||||||
|
f_min: [ 1.]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 192
|
||||||
|
attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
|
||||||
|
num_heads: 8
|
||||||
|
use_scale_shift_norm: True
|
||||||
|
resblock_updown: True
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: "val/rec_loss"
|
||||||
|
ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
|
||||||
|
ddconfig:
|
||||||
|
double_z: True
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: [ ]
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config: "__is_unconditional__"
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 96
|
||||||
|
num_workers: 5
|
||||||
|
wrap: False
|
||||||
|
train:
|
||||||
|
target: ldm.data.lsun.LSUNChurchesTrain
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
validation:
|
||||||
|
target: ldm.data.lsun.LSUNChurchesValidation
|
||||||
|
params:
|
||||||
|
size: 256
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 5000
|
||||||
|
max_images: 8
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
|
@ -0,0 +1,71 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-05
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.012
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: caption
|
||||||
|
image_size: 32
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions:
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 1280
|
||||||
|
use_checkpoint: true
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.BERTEmbedder
|
||||||
|
params:
|
||||||
|
n_embed: 1280
|
||||||
|
n_layer: 32
|
|
@ -0,0 +1,68 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 0.0001
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.0015
|
||||||
|
linear_end: 0.015
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: jpg
|
||||||
|
cond_stage_key: nix
|
||||||
|
image_size: 48
|
||||||
|
channels: 16
|
||||||
|
cond_stage_trainable: false
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_by_std: false
|
||||||
|
scale_factor: 0.22765929
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 48
|
||||||
|
in_channels: 16
|
||||||
|
out_channels: 16
|
||||||
|
model_channels: 448
|
||||||
|
attention_resolutions:
|
||||||
|
- 4
|
||||||
|
- 2
|
||||||
|
- 1
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 3
|
||||||
|
- 4
|
||||||
|
use_scale_shift_norm: false
|
||||||
|
resblock_updown: false
|
||||||
|
num_head_channels: 32
|
||||||
|
use_spatial_transformer: true
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: true
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
monitor: val/rec_loss
|
||||||
|
embed_dim: 16
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 16
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions:
|
||||||
|
- 16
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
cond_stage_config:
|
||||||
|
target: torch.nn.Identity
|
|
@ -0,0 +1,113 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 5.0e-06
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: caption
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 512
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 4
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: ldm.data.local.LocalBase
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
mode: "train"
|
||||||
|
validation:
|
||||||
|
target: ldm.data.local.LocalBase
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
mode: "val"
|
||||||
|
val_split: 64
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 500
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 500
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
log_first_step: False
|
||||||
|
log_images_kwargs:
|
||||||
|
use_ema_scope: False
|
||||||
|
inpaint: False
|
||||||
|
plot_progressive_rows: False
|
||||||
|
plot_diffusion_rows: False
|
||||||
|
N: 4
|
||||||
|
ddim_steps: 50
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
val_check_interval: 5000000
|
||||||
|
num_sanity_val_steps: 0
|
||||||
|
accumulate_grad_batches: 1
|
|
@ -0,0 +1,100 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 50
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: image
|
||||||
|
cond_stage_key: caption
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: true # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 1 ] # NOTE for resuming. use 10000 if starting from scratch
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 512
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
||||||
|
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 1
|
||||||
|
wrap: false
|
||||||
|
train:
|
||||||
|
target: ldm.data.local.LocalBase
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
validation:
|
||||||
|
target: ldm.data.local.LocalBase
|
||||||
|
params:
|
||||||
|
size: 512
|
||||||
|
|
||||||
|
lightning:
|
||||||
|
modelcheckpoint:
|
||||||
|
params:
|
||||||
|
every_n_train_steps: 500
|
||||||
|
callbacks:
|
||||||
|
image_logger:
|
||||||
|
target: main.ImageLogger
|
||||||
|
params:
|
||||||
|
batch_frequency: 500
|
||||||
|
max_images: 4
|
||||||
|
increase_log_steps: False
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
benchmark: True
|
||||||
|
max_steps: 6100
|
|
@ -0,0 +1,70 @@
|
||||||
|
model:
|
||||||
|
base_learning_rate: 1.0e-04
|
||||||
|
target: ldm.models.diffusion.ddpm.LatentDiffusion
|
||||||
|
params:
|
||||||
|
linear_start: 0.00085
|
||||||
|
linear_end: 0.0120
|
||||||
|
num_timesteps_cond: 1
|
||||||
|
log_every_t: 200
|
||||||
|
timesteps: 1000
|
||||||
|
first_stage_key: "jpg"
|
||||||
|
cond_stage_key: "txt"
|
||||||
|
image_size: 64
|
||||||
|
channels: 4
|
||||||
|
cond_stage_trainable: false # Note: different from the one we trained before
|
||||||
|
conditioning_key: crossattn
|
||||||
|
monitor: val/loss_simple_ema
|
||||||
|
scale_factor: 0.18215
|
||||||
|
use_ema: False
|
||||||
|
|
||||||
|
scheduler_config: # 10000 warmup steps
|
||||||
|
target: ldm.lr_scheduler.LambdaLinearScheduler
|
||||||
|
params:
|
||||||
|
warm_up_steps: [ 10000 ]
|
||||||
|
cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
|
||||||
|
f_start: [ 1.e-6 ]
|
||||||
|
f_max: [ 1. ]
|
||||||
|
f_min: [ 1. ]
|
||||||
|
|
||||||
|
unet_config:
|
||||||
|
target: ldm.modules.diffusionmodules.openaimodel.UNetModel
|
||||||
|
params:
|
||||||
|
image_size: 32 # unused
|
||||||
|
in_channels: 4
|
||||||
|
out_channels: 4
|
||||||
|
model_channels: 320
|
||||||
|
attention_resolutions: [ 4, 2, 1 ]
|
||||||
|
num_res_blocks: 2
|
||||||
|
channel_mult: [ 1, 2, 4, 4 ]
|
||||||
|
num_heads: 8
|
||||||
|
use_spatial_transformer: True
|
||||||
|
transformer_depth: 1
|
||||||
|
context_dim: 768
|
||||||
|
use_checkpoint: True
|
||||||
|
legacy: False
|
||||||
|
|
||||||
|
first_stage_config:
|
||||||
|
target: ldm.models.autoencoder.AutoencoderKL
|
||||||
|
params:
|
||||||
|
embed_dim: 4
|
||||||
|
monitor: val/rec_loss
|
||||||
|
ddconfig:
|
||||||
|
double_z: true
|
||||||
|
z_channels: 4
|
||||||
|
resolution: 256
|
||||||
|
in_channels: 3
|
||||||
|
out_ch: 3
|
||||||
|
ch: 128
|
||||||
|
ch_mult:
|
||||||
|
- 1
|
||||||
|
- 2
|
||||||
|
- 4
|
||||||
|
- 4
|
||||||
|
num_res_blocks: 2
|
||||||
|
attn_resolutions: []
|
||||||
|
dropout: 0.0
|
||||||
|
lossconfig:
|
||||||
|
target: torch.nn.Identity
|
||||||
|
|
||||||
|
cond_stage_config:
|
||||||
|
target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
|
|
@ -0,0 +1,80 @@
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import requests
|
||||||
|
import multiprocessing
|
||||||
|
import tqdm
|
||||||
|
|
||||||
|
# downloads URLs from JSON
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--file', '-f', type=str, required=False)
|
||||||
|
parser.add_argument('--out_dir', '-o', type=str, required=False)
|
||||||
|
parser.add_argument('--threads', '-p', required=False, default=32)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
class DownloadManager():
|
||||||
|
def __init__(self, max_threads=32):
|
||||||
|
self.failed_downloads = []
|
||||||
|
self.max_threads = max_threads
|
||||||
|
|
||||||
|
# args = (link, metadata, out_img_dir, out_text_dir)
|
||||||
|
def download(self, args):
|
||||||
|
try:
|
||||||
|
r = requests.get(args[0], stream=True)
|
||||||
|
with open(args[2] + args[0].split('/')[-1], 'wb') as f:
|
||||||
|
for chunk in r.iter_content(1024):
|
||||||
|
f.write(chunk)
|
||||||
|
with open(args[3] + args[0].split('/')[-1].split('.')[0] + '.txt', 'w') as f:
|
||||||
|
f.write(args[1])
|
||||||
|
except:
|
||||||
|
self.failed_downloads.append((args[0], args[1]))
|
||||||
|
|
||||||
|
def download_urls(self, file_path, out_dir):
|
||||||
|
with open(file_path) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
|
||||||
|
if not os.path.exists(out_dir):
|
||||||
|
os.makedirs(out_dir)
|
||||||
|
os.makedirs(out_dir + '/img')
|
||||||
|
os.makedirs(out_dir + '/text')
|
||||||
|
|
||||||
|
thread_args = []
|
||||||
|
|
||||||
|
print(f'Loading {file_path} for download on {self.max_threads} threads...')
|
||||||
|
|
||||||
|
# create initial thread_args
|
||||||
|
for k, v in tqdm.tqdm(data.items()):
|
||||||
|
thread_args.append((k, v, out_dir + 'img/', out_dir + 'text/'))
|
||||||
|
|
||||||
|
# divide thread_args into chunks divisible by max_threads
|
||||||
|
chunks = []
|
||||||
|
for i in range(0, len(thread_args), self.max_threads):
|
||||||
|
chunks.append(thread_args[i:i+self.max_threads])
|
||||||
|
|
||||||
|
print(f'Downloading {len(thread_args)} images...')
|
||||||
|
|
||||||
|
# download chunks synchronously
|
||||||
|
for chunk in tqdm.tqdm(chunks):
|
||||||
|
with multiprocessing.Pool(self.max_threads) as p:
|
||||||
|
p.map(self.download, chunk)
|
||||||
|
|
||||||
|
if len(self.failed_downloads) > 0:
|
||||||
|
print("Failed downloads:")
|
||||||
|
for i in self.failed_downloads:
|
||||||
|
print(i[0])
|
||||||
|
print("\n")
|
||||||
|
"""
|
||||||
|
# attempt to download any remaining failed downloads
|
||||||
|
print('\nAttempting to download any failed downloads...')
|
||||||
|
print('Failed downloads:', len(self.failed_downloads))
|
||||||
|
if len(self.failed_downloads) > 0:
|
||||||
|
for url in tqdm.tqdm(self.failed_downloads):
|
||||||
|
self.download((url[0], url[1], out_dir + 'img/', out_dir + 'text/'))
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
dm = DownloadManager(max_threads=args.threads)
|
||||||
|
dm.download_urls(args.file, args.out_dir)
|
|
@ -0,0 +1,19 @@
|
||||||
|
#resizes and adds a black bar to all images in directory original
|
||||||
|
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
directory = 'original'
|
||||||
|
|
||||||
|
for filename in os.listdir(directory):
|
||||||
|
var1 = directory + '/' + filename
|
||||||
|
os.mkdir('E:/convert/original/' + filename)
|
||||||
|
for i in os.listdir(var1):
|
||||||
|
var4 = var1 + '/'
|
||||||
|
var2 = var1 + '/' + i
|
||||||
|
if os.path.isfile(var2):
|
||||||
|
print(var2)
|
||||||
|
im = Image.open(var2)
|
||||||
|
im = ImageOps.pad(im, (512, 512), color='black')
|
||||||
|
im.save('E:/convert/' + var2)
|
|
@ -0,0 +1,301 @@
|
||||||
|
## This script WAS NOT USED on the weights released by ProjectAI Touhou on 8th of september, 2022.
|
||||||
|
## This script CAN convert tags to human-readable-text BUT IT IS NOT REQUIRED.
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
#Stolen code from https://stackoverflow.com/a/43357954
|
||||||
|
def str2bool(v):
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return v
|
||||||
|
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||||
|
return True
|
||||||
|
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||||
|
|
||||||
|
def ratingparsing(input):
|
||||||
|
v = input.lower()
|
||||||
|
ratingsSelected = " "
|
||||||
|
if "a" in v:
|
||||||
|
ratingsSelected = "e g q s"
|
||||||
|
if "e" in v:
|
||||||
|
ratingsSelected = ratingsSelected + "e "
|
||||||
|
if "g" in v:
|
||||||
|
ratingsSelected = ratingsSelected + "g "
|
||||||
|
if "q" in v:
|
||||||
|
ratingsSelected = ratingsSelected + "q "
|
||||||
|
if "s" in v:
|
||||||
|
ratingsSelected = ratingsSelected + "s "
|
||||||
|
if ratingsSelected == " ":
|
||||||
|
raise Exception('a/e/g/q/s expected')
|
||||||
|
print("Ratings selected: " + ratingsSelected)
|
||||||
|
return(ratingsSelected)
|
||||||
|
## In the future someone might want to access this via import. Consider adding support for that
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--jsonpath', '-J', type=str, help='Path to JSONL file with the metadata', required = True)
|
||||||
|
parser.add_argument('--extractpath', '-E', type=str, help='Path to the folder where to extract the images and text files', required = True)
|
||||||
|
parser.add_argument('--imagespath', '-I', type=str, help='Path to the folder with the images', required = False, default="512px")
|
||||||
|
parser.add_argument('--convtohuman', '-H', type=str2bool, help='Convert to human-readable-text', required = False, default=False)
|
||||||
|
parser.add_argument('--rating', '-R', type=ratingparsing, help='Extract specific rating/s [a/e/g/q/s]', required = False, default='a')
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.convtohuman == True:
|
||||||
|
print("tag conversion to human is currently somewhat broken. If you still want to use it remove line 25")
|
||||||
|
#Q: What is broken?
|
||||||
|
#A: tag_separator sometimes appears at to_write without anything behind it. It should be an easy fix where tag_separator simply does not appear if the variable behind it is blank
|
||||||
|
#but right now its not important, plus many tokens are lost when converting to human text. its more effective doing tag based inputs rather than human-readable text
|
||||||
|
exit()
|
||||||
|
print("Arguments: " + str(args))
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
if os.path.exists(args.extractpath) == False:
|
||||||
|
os.mkdir(args.extractpath)
|
||||||
|
|
||||||
|
def writefile(filename, text):
|
||||||
|
f = open(filename, "w")
|
||||||
|
f.write(text)
|
||||||
|
print('Saved the following: ' + text)
|
||||||
|
f.close()
|
||||||
|
#Converts tags to T2I-like prompts (blue_dress, 1girl -> A blue dress, one girl)
|
||||||
|
|
||||||
|
def ConvCommaAndUnderscoreToHuman(convtohuman, input):
|
||||||
|
tars = input
|
||||||
|
if convtohuman:
|
||||||
|
tars = tars.replace(' ', ', ')
|
||||||
|
tars = tars.replace('_', ' ')
|
||||||
|
elif convtohuman == False:
|
||||||
|
print("CommaAndUnderscoreToHuman: convtohuman is false hence not doing anything")
|
||||||
|
return tars
|
||||||
|
|
||||||
|
def ConvTagsToHuman(convtohuman, input):
|
||||||
|
tars = input
|
||||||
|
if convtohuman:
|
||||||
|
tars = tars.replace('1girl', 'one girl')
|
||||||
|
tars = tars.replace('2girls', 'two girls')
|
||||||
|
tars = tars.replace('3girls', 'three girls')
|
||||||
|
tars = tars.replace('4girls', 'four girls')
|
||||||
|
tars = tars.replace('5girls', 'five girls')
|
||||||
|
##Implying it will ever be able to differentiate so many entities
|
||||||
|
tars = tars.replace('6girls', 'six girls')
|
||||||
|
|
||||||
|
#Almost forgot about boys tags... I wonder if theres also for other entities?
|
||||||
|
tars = tars.replace('1boy', 'one boy')
|
||||||
|
tars = tars.replace('2boys', 'two boys')
|
||||||
|
tars = tars.replace('3boys', 'three boys')
|
||||||
|
tars = tars.replace('4boys', 'four boys')
|
||||||
|
tars = tars.replace('5boys', 'five boys')
|
||||||
|
tars = tars.replace('6boys', 'six boys')
|
||||||
|
elif convtohuman == False:
|
||||||
|
print("ConvTagsToHuman: convtohuman is false hence not doing anything")
|
||||||
|
print("TARS is: " + tars)
|
||||||
|
return tars
|
||||||
|
|
||||||
|
#Converts ratings to X content
|
||||||
|
def ConvRatingToHuman(convtohuman, input):
|
||||||
|
if convtohuman:
|
||||||
|
if input == "e":
|
||||||
|
return "explicit content"
|
||||||
|
if input == "g":
|
||||||
|
return "general content"
|
||||||
|
if input == "q":
|
||||||
|
return "questionable content"
|
||||||
|
if input == "s":
|
||||||
|
return "sensitive content"
|
||||||
|
##This will be the start of everything unethical
|
||||||
|
elif convtohuman == False:
|
||||||
|
if input == "e":
|
||||||
|
return "explicit_content"
|
||||||
|
if input == "g":
|
||||||
|
return "general_content"
|
||||||
|
if input == "q":
|
||||||
|
return "questionable_content"
|
||||||
|
if input == "s":
|
||||||
|
return "sensitive_content"
|
||||||
|
|
||||||
|
def ConvCharacterToHuman(convtohuman, input):
|
||||||
|
tars = input
|
||||||
|
if convtohuman:
|
||||||
|
tars = tars.replace('_(', ' from ')
|
||||||
|
tars = tars.replace(')', '')
|
||||||
|
elif convtohuman == False:
|
||||||
|
print("ConvCharacterToHuman: convtohuman is false hence not doing anything")
|
||||||
|
return tars
|
||||||
|
|
||||||
|
# unrecog_ans = True
|
||||||
|
# while unrecog_ans:
|
||||||
|
# inputans = input("Convert tags to human-readable-text? (smiley_face blue_hair -> smiley face, blue hair) [y/n]")
|
||||||
|
# if inputans == "y":
|
||||||
|
# convtohuman = True
|
||||||
|
# unrecog_ans = False
|
||||||
|
# elif inputans == "n":
|
||||||
|
# convtohuman = False
|
||||||
|
# unrecog_ans = False
|
||||||
|
# else:
|
||||||
|
# print("unrecognizable input. only y or n.")
|
||||||
|
# unrecog_ans = True
|
||||||
|
|
||||||
|
convtohuman = args.convtohuman
|
||||||
|
acceptedRatings = args.rating
|
||||||
|
|
||||||
|
##Open the file
|
||||||
|
json_file_path = args.jsonpath ##Name of the JSON file to use, converted into parser arg
|
||||||
|
with open(json_file_path, 'r', encoding="utf8") as json_file:
|
||||||
|
json_list = list(json_file)
|
||||||
|
|
||||||
|
##Read line
|
||||||
|
current_saved_file_count = 0
|
||||||
|
current_line_count = 0
|
||||||
|
for json_str in json_list:
|
||||||
|
current_line_count = current_line_count + 1
|
||||||
|
##415627 last line of 00.json, ignore
|
||||||
|
##TODO: Add a line counter to print progress accurately
|
||||||
|
print("Current Line:" + str(current_line_count) + '/415000 (aprox) | Current saved files count: ' + str(current_saved_file_count) )
|
||||||
|
#here, result = line
|
||||||
|
result = json.loads(json_str)
|
||||||
|
|
||||||
|
try:
|
||||||
|
img_id = str(result['id'])
|
||||||
|
except Exception:
|
||||||
|
img_id = "nan"
|
||||||
|
print("img_id RETRIVAL FAILED. VAR IS ESSENTIAL SO SKIPPING ENTRY.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
tmp_img_id = img_id[-3:]
|
||||||
|
img_id_last3 = tmp_img_id.zfill(3)
|
||||||
|
except Exception:
|
||||||
|
img_id_last3 = "nan"
|
||||||
|
print("img_id_last3 RETRIVAL FAILED. VAR IS ESSENTIAL SO SKIPPING ENTRY.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# try:
|
||||||
|
# img_tags = result['tag_string']
|
||||||
|
# except Exception:
|
||||||
|
# img_tags = "none"
|
||||||
|
# print("failed to get img_tags")
|
||||||
|
# continue
|
||||||
|
|
||||||
|
##JohannesGaessler SUGGESTIONS: harubaru/waifu-diffusion/pull/11
|
||||||
|
|
||||||
|
## TAG_STRING_GENERAL: ONLY TAGS HERE
|
||||||
|
try:
|
||||||
|
img_tag_string_general = result['tag_string_general']
|
||||||
|
except Exception:
|
||||||
|
img_tag_string_general = None
|
||||||
|
print("img_tag_string_general RETRIVAL FAILED. VAR IS ESSENTIAL SO SKIPPING ENTRY.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
## TAG_STRING_ARTIST: ONLY ARTISTS TAGS HERE
|
||||||
|
try:
|
||||||
|
img_tag_string_artist = result['tag_string_artist']
|
||||||
|
except Exception:
|
||||||
|
img_tag_string_artist = None
|
||||||
|
print("img_tag_string_artist RETRIVAL FAILED. Var is not essential so just skipping var.")
|
||||||
|
pass
|
||||||
|
|
||||||
|
## TAG_STRING_COPYRIGHT: ONLY COPYRIGHT TAGS HERE
|
||||||
|
try:
|
||||||
|
img_tag_string_copyright = result['tag_string_copyright']
|
||||||
|
except Exception:
|
||||||
|
img_tag_string_copyright = None
|
||||||
|
print("img_tag_string_copyright RETRIVAL FAILED. Var is not essential so just skipping var.")
|
||||||
|
pass
|
||||||
|
|
||||||
|
## TAG_STRING_CHARACTER: ONLY CHARACTER TAGS HERE
|
||||||
|
try:
|
||||||
|
img_tag_string_character = result['tag_string_character']
|
||||||
|
except Exception:
|
||||||
|
img_tag_string_character = None
|
||||||
|
print("img_tag_string_character RETRIVAL FAILED. Var is not essential so just skipping var.")
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
img_ext = result['file_ext']
|
||||||
|
except Exception:
|
||||||
|
img_ext = None
|
||||||
|
print("img_ext RETRIVAL FAILED. VAR IS ESSENTIAL SO SKIPPING ENTRY.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
img_rating = result['rating']
|
||||||
|
except Exception:
|
||||||
|
img_rating = None
|
||||||
|
print("img_rating RETRIVAL FAILED. VAR IS ESSENTIAL SO SKIPPING ENTRY.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
baru = img_rating in acceptedRatings
|
||||||
|
|
||||||
|
# print("HEYYYYYYYYYYYYYYYY " + str(baru))
|
||||||
|
|
||||||
|
if str(baru) == "False":
|
||||||
|
print("Entry rating' is not in acceptedRatings, skipping entry.")
|
||||||
|
continue
|
||||||
|
elif str(baru) == "True":
|
||||||
|
print("Entry rating matches!")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
file_path = str(args.imagespath) + "/0" + img_id_last3 + "/" + img_id + "." + img_ext
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
shutil.copyfile(file_path, args.extractpath + '/' + img_id + "." + img_ext)
|
||||||
|
|
||||||
|
##Essential
|
||||||
|
FinalTagStringGeneral = ConvCommaAndUnderscoreToHuman(convtohuman, img_tag_string_general)
|
||||||
|
FinalTagStringGeneral = ConvTagsToHuman(convtohuman, FinalTagStringGeneral)
|
||||||
|
|
||||||
|
##Not essential
|
||||||
|
if img_tag_string_artist != None:
|
||||||
|
FinalTagStringArtist = ConvCommaAndUnderscoreToHuman(convtohuman, img_tag_string_artist)
|
||||||
|
elif img_tag_string_artist == None:
|
||||||
|
print("img_tag_string_artist is none")
|
||||||
|
else:
|
||||||
|
print("CE 1NE")
|
||||||
|
|
||||||
|
if img_tag_string_character != None:
|
||||||
|
FinalTagStringCharacter = ConvCommaAndUnderscoreToHuman(convtohuman, img_tag_string_character)
|
||||||
|
FinalTagStringCharacter = ConvCharacterToHuman(convtohuman, FinalTagStringCharacter)
|
||||||
|
elif img_tag_string_character == None:
|
||||||
|
print("img_tag_string_character is none")
|
||||||
|
else:
|
||||||
|
print("CE 2NE")
|
||||||
|
|
||||||
|
if img_tag_string_copyright != None:
|
||||||
|
FinalTagStringCopyright = ConvCommaAndUnderscoreToHuman(convtohuman, img_tag_string_copyright)
|
||||||
|
elif img_tag_string_copyright == None:
|
||||||
|
print("img_tag_string_copyright is none")
|
||||||
|
else:
|
||||||
|
print("CE 3NE")
|
||||||
|
|
||||||
|
print("IMAGE RATING IS: " + img_rating)
|
||||||
|
|
||||||
|
if img_rating != None:
|
||||||
|
FinalTagStringRating = ConvRatingToHuman(convtohuman, img_rating)
|
||||||
|
elif img_rating == None:
|
||||||
|
print("img_rating is none")
|
||||||
|
else:
|
||||||
|
print("CE 4NE")
|
||||||
|
|
||||||
|
if convtohuman == True:
|
||||||
|
dan_iden = 'uploaded on danbooru'
|
||||||
|
tag_separator = ', '
|
||||||
|
elif convtohuman == False:
|
||||||
|
dan_iden = 'danbooru'
|
||||||
|
tag_separator = ' '
|
||||||
|
# print('FinalTagStringCharacter is: ' + FinalTagStringCharacter)
|
||||||
|
# print('tag_separator is: ' + tag_separator)
|
||||||
|
# print('FinalTagStringArtist is: ' + FinalTagStringArtist)
|
||||||
|
# print('FinalTagStringRating is: ' + FinalTagStringRating)
|
||||||
|
# print('FinalTagStringGeneral is: ' + FinalTagStringGeneral)
|
||||||
|
# print('FinalTagStringCopyright is: ' + FinalTagStringCopyright)
|
||||||
|
to_write = FinalTagStringCharacter + tag_separator + FinalTagStringArtist + tag_separator + FinalTagStringRating + tag_separator + FinalTagStringGeneral + tag_separator + FinalTagStringCopyright
|
||||||
|
txt_name = args.extractpath + "/" + img_id + '.txt'
|
||||||
|
writefile(txt_name, to_write)
|
||||||
|
current_saved_file_count = current_saved_file_count + 1
|
||||||
|
elif os.path.exists(file_path) == False:
|
||||||
|
print("Failed to find path.")
|
||||||
|
|
||||||
|
print("finished process. Your extracted data should be in " + str(args.extractpath) + " !")
|
||||||
|
|
|
@ -0,0 +1,48 @@
|
||||||
|
# with open("nsfw-ids.txt", 'r', encoding="utf8") as nsfwfile:
|
||||||
|
# nsfw_list = list(nsfwfile)
|
||||||
|
import tqdm
|
||||||
|
# ##Read line
|
||||||
|
# current_saved_file_count = 0
|
||||||
|
# current_line_count = 0
|
||||||
|
# for line in nsfw_list:
|
||||||
|
# print(line)
|
||||||
|
# last3_line_raw = line[-4:]
|
||||||
|
# last3_line = last3_line_raw.zfill(4)
|
||||||
|
# print(last3_line_raw)
|
||||||
|
# print(last3_line)
|
||||||
|
|
||||||
|
def file_len(filename):
|
||||||
|
with open(filename) as f:
|
||||||
|
for i, _ in enumerate(f):
|
||||||
|
pass
|
||||||
|
return i + 1
|
||||||
|
|
||||||
|
def writetofile(input):
|
||||||
|
f = open("files2download.txt", "a")
|
||||||
|
f.write(input + "\n")
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
#converts nsfw-ids.txt entries to rsync readable file
|
||||||
|
|
||||||
|
with open("nsfw-ids.txt", 'r', encoding="utf8") as nsfwfile:
|
||||||
|
nsfw_list = list(nsfwfile)
|
||||||
|
count = 0
|
||||||
|
linescount = file_len("nsfw-ids.txt")
|
||||||
|
|
||||||
|
##Read line
|
||||||
|
for line in nsfw_list:
|
||||||
|
line = line.strip()
|
||||||
|
# print(line)
|
||||||
|
linefilled1 = line.zfill(4)
|
||||||
|
linelast3 = linefilled1[-3:]
|
||||||
|
linedirectory = linelast3.zfill(4)
|
||||||
|
# print("line: " + ">>"+ line + "<<")
|
||||||
|
# print("Linefilled1: " + linefilled1)
|
||||||
|
# print("linelast3: " + linelast3)
|
||||||
|
# print("linedirectory: " + linedirectory)
|
||||||
|
directory = "original/" + linedirectory + "/" + line + ".jpg"
|
||||||
|
# print(directory)
|
||||||
|
# print(directory2)
|
||||||
|
writetofile(directory)
|
||||||
|
count = count + 1
|
||||||
|
print(str(count) + "/" + str(linescount))
|
|
@ -0,0 +1,50 @@
|
||||||
|
import threading
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
from pybooru import Danbooru
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--danbooru_username', '-user', type=str, required=False)
|
||||||
|
parser.add_argument('--danbooru_key', '-key', type=str, required=False)
|
||||||
|
parser.add_argument('--tags', '-t', required=False, default="solo -comic -animated -touhou -rating:general order:score age:<1month")
|
||||||
|
parser.add_argument('--posts', '-p', required=False, default=10000)
|
||||||
|
parser.add_argument('--output', '-o', required=False, default='links.json')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
class DanbooruScraper():
|
||||||
|
def __init__(self, username, key):
|
||||||
|
self.username = username
|
||||||
|
self.key = key
|
||||||
|
self.dbclient = Danbooru('danbooru', username=self.username, api_key=self.key)
|
||||||
|
|
||||||
|
# This will get danbooru urls and tags, put them in a dict, then write as a json file
|
||||||
|
def get_urls(self, tags, num_posts, batch_size, file="data_urls.json"):
|
||||||
|
dict = {}
|
||||||
|
if num_posts % batch_size != 0:
|
||||||
|
print("Error: num_posts must be divisible by batch_size")
|
||||||
|
return
|
||||||
|
for i in tqdm(range(num_posts//batch_size)):
|
||||||
|
urls = self.dbclient.post_list(tags=tags, limit=batch_size, random=False, page=i)
|
||||||
|
if not urls:
|
||||||
|
print(f'Empty results at {i}')
|
||||||
|
break
|
||||||
|
for j in urls:
|
||||||
|
if 'file_url' in j:
|
||||||
|
if j['file_url'] not in dict:
|
||||||
|
d_url = j['file_url']
|
||||||
|
d_tags = j['tag_string_copyright'] + " " + j['tag_string_character'] + " " + j['tag_string_general'] + " " + j['tag_string_artist']
|
||||||
|
|
||||||
|
dict[d_url] = d_tags
|
||||||
|
else:
|
||||||
|
print("Error: file_url not found")
|
||||||
|
with open(file, 'w') as f:
|
||||||
|
json.dump(dict, f)
|
||||||
|
|
||||||
|
# now test
|
||||||
|
if __name__ == "__main__":
|
||||||
|
ds = DanbooruScraper(args.danbooru_username, args.danbooru_key)
|
||||||
|
ds.get_urls(args.tags, args.posts, 100, file=args.output)
|
After Width: | Height: | Size: 14 KiB |
|
@ -0,0 +1 @@
|
||||||
|
A basket of cerries
|
After Width: | Height: | Size: 466 KiB |
After Width: | Height: | Size: 7.4 KiB |
After Width: | Height: | Size: 539 KiB |
After Width: | Height: | Size: 7.6 KiB |
After Width: | Height: | Size: 450 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 553 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 418 KiB |
After Width: | Height: | Size: 6.1 KiB |
After Width: | Height: | Size: 542 KiB |
After Width: | Height: | Size: 9.5 KiB |
After Width: | Height: | Size: 395 KiB |
After Width: | Height: | Size: 12 KiB |
After Width: | Height: | Size: 465 KiB |
After Width: | Height: | Size: 7.8 KiB |
|
@ -0,0 +1,7 @@
|
||||||
|
# Documentation
|
||||||
|
|
||||||
|
Waifu Diffusion is a project based off CompVis/Stable-Diffusion.
|
||||||
|
|
||||||
|
For guidance on how to start training, see [training](./training/README.md).
|
||||||
|
|
||||||
|
For a list of trained weights, see [weights](./weights/README.md).
|
|
@ -0,0 +1,8 @@
|
||||||
|
# Training documentation
|
||||||
|
Training is available with waifu-diffusion. Before starting, we remind you that, at this moment at least 30GB of VRAM is needed, along with at least 30gb of storage if you don't mind cleaning up every so often.
|
||||||
|
## Contents
|
||||||
|
1. [Dataset](./dataset.md)
|
||||||
|
2. [Configuration](./configuration.md)
|
||||||
|
3. [Executing](./executing.md)
|
||||||
|
4. Recommendations
|
||||||
|
5. FAQ
|
|
@ -0,0 +1,3 @@
|
||||||
|
# 2. Configuration
|
||||||
|
This section is to be done on the machine where you are going to train.
|
||||||
|
Soon because my instance is on maintenance
|
|
@ -0,0 +1,120 @@
|
||||||
|
# 1. Dataset
|
||||||
|
|
||||||
|
In this guide we are going to use the Danbooru2021 dataset by Gwern.net. You are free to use any other dataset as long as you know how to convert it to the right format.
|
||||||
|
|
||||||
|
## Contents
|
||||||
|
1. Dataset requirements
|
||||||
|
2. Downloading the dataset
|
||||||
|
3. Organizing the dataset
|
||||||
|
4. Packaging the dataset
|
||||||
|
|
||||||
|
## Dataset requirements
|
||||||
|
|
||||||
|
The dataset needs to be in the following format
|
||||||
|
|
||||||
|
/dataset/ : Root dataset folder, can be any name
|
||||||
|
|
||||||
|
/dataset/img/ : Folder for images
|
||||||
|
|
||||||
|
/dataset/txt/ : Folder for text files
|
||||||
|
|
||||||
|
It is recommended to have the images in 512x512 resolution and in JPG format. While the text files need to have the same name as the images it refers to.
|
||||||
|
|
||||||
|
Foe example:
|
||||||
|
````
|
||||||
|
mydataset
|
||||||
|
├── img
|
||||||
|
│ └── image001.jpg
|
||||||
|
└── txt
|
||||||
|
└── image001.txt
|
||||||
|
````
|
||||||
|
Where image001.txt has the tags (prompt) to be used for image001.jpg
|
||||||
|
|
||||||
|
## Downloading the dataset
|
||||||
|
This is optional; If you have your own dataset skip this part.
|
||||||
|
|
||||||
|
### Downloading Rsync
|
||||||
|
Danbooru2021 is available for download through rsync.
|
||||||
|
#### Linux
|
||||||
|
On Linux, you should be able to install rsync via your package manager.
|
||||||
|
````bash
|
||||||
|
apt install rsync
|
||||||
|
````
|
||||||
|
#### Windows
|
||||||
|
On Windows, you are going to need to install Cygwin, a posix runtime for Windows which allows the usage of many linux-only programs inside windows.
|
||||||
|
|
||||||
|
[Cygwin Installer for x86](https://www.cygwin.com/setup-x86_64.exe)
|
||||||
|
|
||||||
|
On the installer, select mirrors.kernel.org for Download Site:
|
||||||
|
|
||||||
|
![cygwin-mirrors.png](./res/cygwin-mirrors.png)
|
||||||
|
|
||||||
|
Next, search for "rsync" on the search bar, change "View: Pending" to "View: Full", and select on the "New" tab the latest version. Do the same for "zip".
|
||||||
|
|
||||||
|
![cygwin-packages.png](./res/cygwin-packages.png)
|
||||||
|
|
||||||
|
GIF explaining the entire process:
|
||||||
|
|
||||||
|
![cygwin-gif.gif](./res/cygwin-gif.gif)
|
||||||
|
|
||||||
|
Once the installation is finished, you should see "Cygwin64 Terminal" on your Start Menu. Launch it and you should be greated by the following window:
|
||||||
|
|
||||||
|
![cygwin-idle.png](./res/cygwin-idle.png)
|
||||||
|
|
||||||
|
You may now follow the intructions
|
||||||
|
|
||||||
|
### Downloading the dataset
|
||||||
|
Remember that instructions here apply universally, both on Linux and Windows (If you are using Cygwin that is).
|
||||||
|
|
||||||
|
The entire dataset weights about 5TB. You are not going to download everything, instead, you are only going to download two kinds of files:
|
||||||
|
|
||||||
|
1. The images
|
||||||
|
2. The JSON files (metadata)
|
||||||
|
|
||||||
|
If you want to see the entire file list, you can refer to the [Danbooru2021 information site](https://www.gwern.net/Danbooru2021).
|
||||||
|
|
||||||
|
We are going to extract the images from the 512px folder for convinience, since this folder already has the images resized to 512x512 resolution in JPG format. It only has safe rated images, for NSFW refer to [gwern.net](https://www.gwern.net/Danbooru2021#samples).
|
||||||
|
|
||||||
|
Folders from 0000 to 0009.
|
||||||
|
> The folders are named according to the last 3 digits of the image ID on danbooru. Images on folder 0001 will have its ID end on 001.
|
||||||
|
|
||||||
|
We are also going to download the only the first JSON batch. If you want to train on more data you should download more JSON batches.
|
||||||
|
|
||||||
|
Download the 512px folders from 0000 to 0009 (3.86GB):
|
||||||
|
```bash
|
||||||
|
rsync -r rsync://176.9.41.242:873/danbooru2021/512px/000* ./512px/
|
||||||
|
```
|
||||||
|
Download the first batch of metadata, posts000000000000.json (800MB):
|
||||||
|
``` shell
|
||||||
|
rsync rsync://176.9.41.242:873/danbooru2021/metadata/posts000000000000.json ./metadata/
|
||||||
|
```
|
||||||
|
You should now have two folders named: 512px and metadata.
|
||||||
|
|
||||||
|
## Organizing the dataset
|
||||||
|
Although we have the dataset, the metadata that explains what the image is, is inside the JSON file. In order to extract the data into individual txt files, we are going to use the script inside ``danbooru_data/local/extractfromjson_danboo21.py``
|
||||||
|
|
||||||
|
Assuming you are in the same directory as metadata and 512px folder:
|
||||||
|
````bash
|
||||||
|
python danbooru_data/local/extractfromjson_danboo21.py -J metadata/posts000000000000.json -E danbooru-aesthetic
|
||||||
|
````
|
||||||
|
|
||||||
|
Once the script has finished, you should have a "danbooru-aesthetic" folder, whose insides look like this:
|
||||||
|
|
||||||
|
![labeled_data-insides.png](./res/labeled_data-insides.png)
|
||||||
|
|
||||||
|
## Packaging the dataset
|
||||||
|
Next we need to put the extracted data into the format required in the section "Dataset requirements". Run the following commands:
|
||||||
|
``` shell
|
||||||
|
mkdir danbooru-aesthetic/img danbooru-aesthetic/txt
|
||||||
|
mv danbooru-aesthetic/*.jpg danbooru-aesthetic/img
|
||||||
|
mv danbooru-aesthetic/*.txt danbooru-aesthetic/txt
|
||||||
|
```
|
||||||
|
|
||||||
|
In order to reduce size, zip the contents of labeled_data:
|
||||||
|
``` shell
|
||||||
|
zip -r danbooru-aesthetic.zip danbooru-aesthetic
|
||||||
|
```
|
||||||
|
This will package the entire danbooru-aesthetic folder into a zip file. This command DOES NOT output any information in the terminal, so be patient.
|
||||||
|
|
||||||
|
## Finish
|
||||||
|
You can now continue to Configure
|
|
@ -0,0 +1,51 @@
|
||||||
|
# 3. Executing
|
||||||
|
|
||||||
|
There are two modes of executing the training:
|
||||||
|
1. Using docker image. This is the fastest way to get started.
|
||||||
|
2. Using system python install. Allows more customization.
|
||||||
|
|
||||||
|
Note: You will need to provide the initial checkpoint for resuming the training. This must be a version with the full EMA. Otherwise you will get this error:
|
||||||
|
```
|
||||||
|
RuntimeError: Error(s) in loading state_dict for LatentDiffusion:
|
||||||
|
Missing key(s) in state_dict: "model_ema.diffusion_modeltime_embed0weight", "model_ema.diffusion_modeltime_embed0bias".... (Many lines of similar outputs)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 1. Using docker image
|
||||||
|
|
||||||
|
An image is provided at `ghcr.io/derfred/waifu-diffusion`. Execute it using by adjusting the NUM_GPU variable:
|
||||||
|
```
|
||||||
|
docker run -it -e NUM_GPU=x ghcr.io/derfred/waifu-diffusion
|
||||||
|
```
|
||||||
|
|
||||||
|
Next you will want to download the starting checkpoint into the file `model.ckpt` and copy the training data in the directory `/waifu/danbooru-aesthetic`.
|
||||||
|
|
||||||
|
Finally execute the training using:
|
||||||
|
```
|
||||||
|
sh train.sh -t -n "aesthetic" --resume_from_checkpoint model.ckpt --base ./configs/stable-diffusion/v1-finetune-4gpu.yaml --no-test --seed 25 --scale_lr False --data_root "./danbooru-aesthetic"
|
||||||
|
```
|
||||||
|
|
||||||
|
## 2. system python install
|
||||||
|
|
||||||
|
First install the dependencies:
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
Next you will want to download the starting checkpoint into the file `model.ckpt` and copy the training data in the directory `/waifu/danbooru-aesthetic`.
|
||||||
|
|
||||||
|
Also you will need to edit the configuration in `./configs/stable-diffusion/v1-finetune-4gpu.yaml`. In the `data` section (around line 70) change the `batch_size` and `num_workers` to the number of GPUs you are using:
|
||||||
|
```
|
||||||
|
data:
|
||||||
|
target: main.DataModuleFromConfig
|
||||||
|
params:
|
||||||
|
batch_size: 4
|
||||||
|
num_workers: 4
|
||||||
|
wrap: false
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally execute the training using the following command. You need to adjust the `--gpu` parameter according to your GPU settings.
|
||||||
|
```bash
|
||||||
|
sh train.sh -t -n "aesthetic" --resume_from_checkpoint model.ckpt --base ./configs/stable-diffusion/v1-finetune-4gpu.yaml --no-test --seed 25 --scale_lr False --data_root "./danbooru-aesthetic" --gpu=0,1,2,3,
|
||||||
|
```
|
||||||
|
|
||||||
|
In case you get an error stating `KeyError: 'Trying to restore optimizer state but checkpoint contains only the model. This is probably due to ModelCheckpoint.save_weights_only being set to True.'` follow these instructions: https://discord.com/channels/930499730843250783/953132470528798811/1018668937052962908
|
After Width: | Height: | Size: 83 MiB |
After Width: | Height: | Size: 4.6 KiB |
After Width: | Height: | Size: 20 KiB |
After Width: | Height: | Size: 153 KiB |
After Width: | Height: | Size: 173 KiB |
|
@ -0,0 +1,15 @@
|
||||||
|
# Weights
|
||||||
|
|
||||||
|
The following is a small list of available weights released by the Waifu Diffusion project:
|
||||||
|
|
||||||
|
- Waifu Diffusion v1.2
|
||||||
|
|
||||||
|
Release Date: 07/09/2022
|
||||||
|
|
||||||
|
Steps/Epochs/Images: 5 Epochs, 56,000 Images
|
||||||
|
|
||||||
|
Download: [Mirrors](./danbooru-7-09-2022/README.md)
|
||||||
|
|
||||||
|
License: None
|
||||||
|
|
||||||
|
Authors: Haru (haru#1367@discord)
|
|
@ -0,0 +1,19 @@
|
||||||
|
Waifu Diffusion v1.2
|
||||||
|
|
||||||
|
Release Date: 07/09/2022
|
||||||
|
|
||||||
|
Steps/Epochs/Images: 5 Epochs, 56,000 Images
|
||||||
|
|
||||||
|
License: None
|
||||||
|
|
||||||
|
Authors: Haru (haru#1367@discord)
|
||||||
|
|
||||||
|
Mirrors:
|
||||||
|
|
||||||
|
Google Drive (rate limit): https://drive.google.com/file/d/1XeoFCILTcc9kn_5uS-G0uqWS5XVANpha
|
||||||
|
|
||||||
|
Magnet Link: magnet:?xt=urn:btih:INEYUMLLBBMZF22IIP4AEXLUK6XQKCSD&dn=wd-v1-2-full-ema.ckpt&xl=7703810927&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
|
||||||
|
|
||||||
|
HTTPS mirror: https://thisanimedoesnotexist.ai/downloads/wd-v1-2-full-ema.ckpt (Fastest)
|
||||||
|
|
||||||
|
HTTP mirror: http://wd.links.sd:8880/wd-v1-2-full-ema.ckpt
|
|
@ -0,0 +1,32 @@
|
||||||
|
name: ldm
|
||||||
|
channels:
|
||||||
|
- pytorch
|
||||||
|
- defaults
|
||||||
|
dependencies:
|
||||||
|
- git
|
||||||
|
- python=3.8.5
|
||||||
|
- pip=20.3
|
||||||
|
- cudatoolkit=11.3
|
||||||
|
- pytorch=1.11.0
|
||||||
|
- torchvision=0.12.0
|
||||||
|
- numpy=1.19.2
|
||||||
|
- pip:
|
||||||
|
- albumentations==0.4.3
|
||||||
|
- opencv-python==4.1.2.30
|
||||||
|
- pudb==2019.2
|
||||||
|
- imageio==2.9.0
|
||||||
|
- imageio-ffmpeg==0.4.2
|
||||||
|
- pytorch-lightning==1.4.2
|
||||||
|
- omegaconf==2.1.1
|
||||||
|
- test-tube>=0.7.5
|
||||||
|
- streamlit>=0.73.1
|
||||||
|
- einops==0.3.0
|
||||||
|
- torch-fidelity==0.3.0
|
||||||
|
- transformers==4.19.2
|
||||||
|
- torchmetrics==0.6.0
|
||||||
|
- kornia==0.6
|
||||||
|
- gradio==3.1.6
|
||||||
|
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
|
||||||
|
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
|
||||||
|
- -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
|
||||||
|
- -e .
|
|
@ -0,0 +1,23 @@
|
||||||
|
from abc import abstractmethod
|
||||||
|
from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
|
||||||
|
|
||||||
|
|
||||||
|
class Txt2ImgIterableBaseDataset(IterableDataset):
|
||||||
|
'''
|
||||||
|
Define an interface to make the IterableDatasets for text2img data chainable
|
||||||
|
'''
|
||||||
|
def __init__(self, num_records=0, valid_ids=None, size=256):
|
||||||
|
super().__init__()
|
||||||
|
self.num_records = num_records
|
||||||
|
self.valid_ids = valid_ids
|
||||||
|
self.sample_ids = valid_ids
|
||||||
|
self.size = size
|
||||||
|
|
||||||
|
print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.num_records
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __iter__(self):
|
||||||
|
pass
|
|
@ -0,0 +1,394 @@
|
||||||
|
import os, yaml, pickle, shutil, tarfile, glob
|
||||||
|
import cv2
|
||||||
|
import albumentations
|
||||||
|
import PIL
|
||||||
|
import numpy as np
|
||||||
|
import torchvision.transforms.functional as TF
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from functools import partial
|
||||||
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.utils.data import Dataset, Subset
|
||||||
|
|
||||||
|
import taming.data.utils as tdu
|
||||||
|
from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve
|
||||||
|
from taming.data.imagenet import ImagePaths
|
||||||
|
|
||||||
|
from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light
|
||||||
|
|
||||||
|
|
||||||
|
def synset2idx(path_to_yaml="data/index_synset.yaml"):
|
||||||
|
with open(path_to_yaml) as f:
|
||||||
|
di2s = yaml.load(f)
|
||||||
|
return dict((v,k) for k,v in di2s.items())
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetBase(Dataset):
|
||||||
|
def __init__(self, config=None):
|
||||||
|
self.config = config or OmegaConf.create()
|
||||||
|
if not type(self.config)==dict:
|
||||||
|
self.config = OmegaConf.to_container(self.config)
|
||||||
|
self.keep_orig_class_label = self.config.get("keep_orig_class_label", False)
|
||||||
|
self.process_images = True # if False we skip loading & processing images and self.data contains filepaths
|
||||||
|
self._prepare()
|
||||||
|
self._prepare_synset_to_human()
|
||||||
|
self._prepare_idx_to_synset()
|
||||||
|
self._prepare_human_to_integer_label()
|
||||||
|
self._load()
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
return self.data[i]
|
||||||
|
|
||||||
|
def _prepare(self):
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def _filter_relpaths(self, relpaths):
|
||||||
|
ignore = set([
|
||||||
|
"n06596364_9591.JPEG",
|
||||||
|
])
|
||||||
|
relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore]
|
||||||
|
if "sub_indices" in self.config:
|
||||||
|
indices = str_to_indices(self.config["sub_indices"])
|
||||||
|
synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings
|
||||||
|
self.synset2idx = synset2idx(path_to_yaml=self.idx2syn)
|
||||||
|
files = []
|
||||||
|
for rpath in relpaths:
|
||||||
|
syn = rpath.split("/")[0]
|
||||||
|
if syn in synsets:
|
||||||
|
files.append(rpath)
|
||||||
|
return files
|
||||||
|
else:
|
||||||
|
return relpaths
|
||||||
|
|
||||||
|
def _prepare_synset_to_human(self):
|
||||||
|
SIZE = 2655750
|
||||||
|
URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
|
||||||
|
self.human_dict = os.path.join(self.root, "synset_human.txt")
|
||||||
|
if (not os.path.exists(self.human_dict) or
|
||||||
|
not os.path.getsize(self.human_dict)==SIZE):
|
||||||
|
download(URL, self.human_dict)
|
||||||
|
|
||||||
|
def _prepare_idx_to_synset(self):
|
||||||
|
URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
|
||||||
|
self.idx2syn = os.path.join(self.root, "index_synset.yaml")
|
||||||
|
if (not os.path.exists(self.idx2syn)):
|
||||||
|
download(URL, self.idx2syn)
|
||||||
|
|
||||||
|
def _prepare_human_to_integer_label(self):
|
||||||
|
URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1"
|
||||||
|
self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt")
|
||||||
|
if (not os.path.exists(self.human2integer)):
|
||||||
|
download(URL, self.human2integer)
|
||||||
|
with open(self.human2integer, "r") as f:
|
||||||
|
lines = f.read().splitlines()
|
||||||
|
assert len(lines) == 1000
|
||||||
|
self.human2integer_dict = dict()
|
||||||
|
for line in lines:
|
||||||
|
value, key = line.split(":")
|
||||||
|
self.human2integer_dict[key] = int(value)
|
||||||
|
|
||||||
|
def _load(self):
|
||||||
|
with open(self.txt_filelist, "r") as f:
|
||||||
|
self.relpaths = f.read().splitlines()
|
||||||
|
l1 = len(self.relpaths)
|
||||||
|
self.relpaths = self._filter_relpaths(self.relpaths)
|
||||||
|
print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))
|
||||||
|
|
||||||
|
self.synsets = [p.split("/")[0] for p in self.relpaths]
|
||||||
|
self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]
|
||||||
|
|
||||||
|
unique_synsets = np.unique(self.synsets)
|
||||||
|
class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets))
|
||||||
|
if not self.keep_orig_class_label:
|
||||||
|
self.class_labels = [class_dict[s] for s in self.synsets]
|
||||||
|
else:
|
||||||
|
self.class_labels = [self.synset2idx[s] for s in self.synsets]
|
||||||
|
|
||||||
|
with open(self.human_dict, "r") as f:
|
||||||
|
human_dict = f.read().splitlines()
|
||||||
|
human_dict = dict(line.split(maxsplit=1) for line in human_dict)
|
||||||
|
|
||||||
|
self.human_labels = [human_dict[s] for s in self.synsets]
|
||||||
|
|
||||||
|
labels = {
|
||||||
|
"relpath": np.array(self.relpaths),
|
||||||
|
"synsets": np.array(self.synsets),
|
||||||
|
"class_label": np.array(self.class_labels),
|
||||||
|
"human_label": np.array(self.human_labels),
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.process_images:
|
||||||
|
self.size = retrieve(self.config, "size", default=256)
|
||||||
|
self.data = ImagePaths(self.abspaths,
|
||||||
|
labels=labels,
|
||||||
|
size=self.size,
|
||||||
|
random_crop=self.random_crop,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.data = self.abspaths
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetTrain(ImageNetBase):
|
||||||
|
NAME = "ILSVRC2012_train"
|
||||||
|
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
||||||
|
AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
|
||||||
|
FILES = [
|
||||||
|
"ILSVRC2012_img_train.tar",
|
||||||
|
]
|
||||||
|
SIZES = [
|
||||||
|
147897477120,
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, process_images=True, data_root=None, **kwargs):
|
||||||
|
self.process_images = process_images
|
||||||
|
self.data_root = data_root
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def _prepare(self):
|
||||||
|
if self.data_root:
|
||||||
|
self.root = os.path.join(self.data_root, self.NAME)
|
||||||
|
else:
|
||||||
|
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
||||||
|
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
||||||
|
|
||||||
|
self.datadir = os.path.join(self.root, "data")
|
||||||
|
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
||||||
|
self.expected_length = 1281167
|
||||||
|
self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
|
||||||
|
default=True)
|
||||||
|
if not tdu.is_prepared(self.root):
|
||||||
|
# prep
|
||||||
|
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
||||||
|
|
||||||
|
datadir = self.datadir
|
||||||
|
if not os.path.exists(datadir):
|
||||||
|
path = os.path.join(self.root, self.FILES[0])
|
||||||
|
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
||||||
|
import academictorrents as at
|
||||||
|
atpath = at.get(self.AT_HASH, datastore=self.root)
|
||||||
|
assert atpath == path
|
||||||
|
|
||||||
|
print("Extracting {} to {}".format(path, datadir))
|
||||||
|
os.makedirs(datadir, exist_ok=True)
|
||||||
|
with tarfile.open(path, "r:") as tar:
|
||||||
|
tar.extractall(path=datadir)
|
||||||
|
|
||||||
|
print("Extracting sub-tars.")
|
||||||
|
subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
|
||||||
|
for subpath in tqdm(subpaths):
|
||||||
|
subdir = subpath[:-len(".tar")]
|
||||||
|
os.makedirs(subdir, exist_ok=True)
|
||||||
|
with tarfile.open(subpath, "r:") as tar:
|
||||||
|
tar.extractall(path=subdir)
|
||||||
|
|
||||||
|
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
||||||
|
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
||||||
|
filelist = sorted(filelist)
|
||||||
|
filelist = "\n".join(filelist)+"\n"
|
||||||
|
with open(self.txt_filelist, "w") as f:
|
||||||
|
f.write(filelist)
|
||||||
|
|
||||||
|
tdu.mark_prepared(self.root)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetValidation(ImageNetBase):
|
||||||
|
NAME = "ILSVRC2012_validation"
|
||||||
|
URL = "http://www.image-net.org/challenges/LSVRC/2012/"
|
||||||
|
AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
|
||||||
|
VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
|
||||||
|
FILES = [
|
||||||
|
"ILSVRC2012_img_val.tar",
|
||||||
|
"validation_synset.txt",
|
||||||
|
]
|
||||||
|
SIZES = [
|
||||||
|
6744924160,
|
||||||
|
1950000,
|
||||||
|
]
|
||||||
|
|
||||||
|
def __init__(self, process_images=True, data_root=None, **kwargs):
|
||||||
|
self.data_root = data_root
|
||||||
|
self.process_images = process_images
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def _prepare(self):
|
||||||
|
if self.data_root:
|
||||||
|
self.root = os.path.join(self.data_root, self.NAME)
|
||||||
|
else:
|
||||||
|
cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
|
||||||
|
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
|
||||||
|
self.datadir = os.path.join(self.root, "data")
|
||||||
|
self.txt_filelist = os.path.join(self.root, "filelist.txt")
|
||||||
|
self.expected_length = 50000
|
||||||
|
self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
|
||||||
|
default=False)
|
||||||
|
if not tdu.is_prepared(self.root):
|
||||||
|
# prep
|
||||||
|
print("Preparing dataset {} in {}".format(self.NAME, self.root))
|
||||||
|
|
||||||
|
datadir = self.datadir
|
||||||
|
if not os.path.exists(datadir):
|
||||||
|
path = os.path.join(self.root, self.FILES[0])
|
||||||
|
if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]:
|
||||||
|
import academictorrents as at
|
||||||
|
atpath = at.get(self.AT_HASH, datastore=self.root)
|
||||||
|
assert atpath == path
|
||||||
|
|
||||||
|
print("Extracting {} to {}".format(path, datadir))
|
||||||
|
os.makedirs(datadir, exist_ok=True)
|
||||||
|
with tarfile.open(path, "r:") as tar:
|
||||||
|
tar.extractall(path=datadir)
|
||||||
|
|
||||||
|
vspath = os.path.join(self.root, self.FILES[1])
|
||||||
|
if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]:
|
||||||
|
download(self.VS_URL, vspath)
|
||||||
|
|
||||||
|
with open(vspath, "r") as f:
|
||||||
|
synset_dict = f.read().splitlines()
|
||||||
|
synset_dict = dict(line.split() for line in synset_dict)
|
||||||
|
|
||||||
|
print("Reorganizing into synset folders")
|
||||||
|
synsets = np.unique(list(synset_dict.values()))
|
||||||
|
for s in synsets:
|
||||||
|
os.makedirs(os.path.join(datadir, s), exist_ok=True)
|
||||||
|
for k, v in synset_dict.items():
|
||||||
|
src = os.path.join(datadir, k)
|
||||||
|
dst = os.path.join(datadir, v)
|
||||||
|
shutil.move(src, dst)
|
||||||
|
|
||||||
|
filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
|
||||||
|
filelist = [os.path.relpath(p, start=datadir) for p in filelist]
|
||||||
|
filelist = sorted(filelist)
|
||||||
|
filelist = "\n".join(filelist)+"\n"
|
||||||
|
with open(self.txt_filelist, "w") as f:
|
||||||
|
f.write(filelist)
|
||||||
|
|
||||||
|
tdu.mark_prepared(self.root)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetSR(Dataset):
|
||||||
|
def __init__(self, size=None,
|
||||||
|
degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1.,
|
||||||
|
random_crop=True):
|
||||||
|
"""
|
||||||
|
Imagenet Superresolution Dataloader
|
||||||
|
Performs following ops in order:
|
||||||
|
1. crops a crop of size s from image either as random or center crop
|
||||||
|
2. resizes crop to size with cv2.area_interpolation
|
||||||
|
3. degrades resized crop with degradation_fn
|
||||||
|
|
||||||
|
:param size: resizing to size after cropping
|
||||||
|
:param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light
|
||||||
|
:param downscale_f: Low Resolution Downsample factor
|
||||||
|
:param min_crop_f: determines crop size s,
|
||||||
|
where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f)
|
||||||
|
:param max_crop_f: ""
|
||||||
|
:param data_root:
|
||||||
|
:param random_crop:
|
||||||
|
"""
|
||||||
|
self.base = self.get_base()
|
||||||
|
assert size
|
||||||
|
assert (size / downscale_f).is_integer()
|
||||||
|
self.size = size
|
||||||
|
self.LR_size = int(size / downscale_f)
|
||||||
|
self.min_crop_f = min_crop_f
|
||||||
|
self.max_crop_f = max_crop_f
|
||||||
|
assert(max_crop_f <= 1.)
|
||||||
|
self.center_crop = not random_crop
|
||||||
|
|
||||||
|
self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
|
self.pil_interpolation = False # gets reset later if incase interp_op is from pillow
|
||||||
|
|
||||||
|
if degradation == "bsrgan":
|
||||||
|
self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f)
|
||||||
|
|
||||||
|
elif degradation == "bsrgan_light":
|
||||||
|
self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f)
|
||||||
|
|
||||||
|
else:
|
||||||
|
interpolation_fn = {
|
||||||
|
"cv_nearest": cv2.INTER_NEAREST,
|
||||||
|
"cv_bilinear": cv2.INTER_LINEAR,
|
||||||
|
"cv_bicubic": cv2.INTER_CUBIC,
|
||||||
|
"cv_area": cv2.INTER_AREA,
|
||||||
|
"cv_lanczos": cv2.INTER_LANCZOS4,
|
||||||
|
"pil_nearest": PIL.Image.NEAREST,
|
||||||
|
"pil_bilinear": PIL.Image.BILINEAR,
|
||||||
|
"pil_bicubic": PIL.Image.BICUBIC,
|
||||||
|
"pil_box": PIL.Image.BOX,
|
||||||
|
"pil_hamming": PIL.Image.HAMMING,
|
||||||
|
"pil_lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[degradation]
|
||||||
|
|
||||||
|
self.pil_interpolation = degradation.startswith("pil_")
|
||||||
|
|
||||||
|
if self.pil_interpolation:
|
||||||
|
self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size,
|
||||||
|
interpolation=interpolation_fn)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.base)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = self.base[i]
|
||||||
|
image = Image.open(example["file_path_"])
|
||||||
|
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
|
||||||
|
min_side_len = min(image.shape[:2])
|
||||||
|
crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None)
|
||||||
|
crop_side_len = int(crop_side_len)
|
||||||
|
|
||||||
|
if self.center_crop:
|
||||||
|
self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len)
|
||||||
|
|
||||||
|
else:
|
||||||
|
self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len)
|
||||||
|
|
||||||
|
image = self.cropper(image=image)["image"]
|
||||||
|
image = self.image_rescaler(image=image)["image"]
|
||||||
|
|
||||||
|
if self.pil_interpolation:
|
||||||
|
image_pil = PIL.Image.fromarray(image)
|
||||||
|
LR_image = self.degradation_process(image_pil)
|
||||||
|
LR_image = np.array(LR_image).astype(np.uint8)
|
||||||
|
|
||||||
|
else:
|
||||||
|
LR_image = self.degradation_process(image=image)["image"]
|
||||||
|
|
||||||
|
example["image"] = (image/127.5 - 1.0).astype(np.float32)
|
||||||
|
example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32)
|
||||||
|
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetSRTrain(ImageNetSR):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def get_base(self):
|
||||||
|
with open("data/imagenet_train_hr_indices.p", "rb") as f:
|
||||||
|
indices = pickle.load(f)
|
||||||
|
dset = ImageNetTrain(process_images=False,)
|
||||||
|
return Subset(dset, indices)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageNetSRValidation(ImageNetSR):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def get_base(self):
|
||||||
|
with open("data/imagenet_val_hr_indices.p", "rb") as f:
|
||||||
|
indices = pickle.load(f)
|
||||||
|
dset = ImageNetValidation(process_images=False,)
|
||||||
|
return Subset(dset, indices)
|
|
@ -0,0 +1,168 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
import glob
|
||||||
|
|
||||||
|
import random
|
||||||
|
|
||||||
|
PIL.Image.MAX_IMAGE_PIXELS = 933120000
|
||||||
|
|
||||||
|
class LocalBase(Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
data_root='./danbooru-aesthetic',
|
||||||
|
size=512,
|
||||||
|
interpolation="bicubic",
|
||||||
|
flip_p=0.5,
|
||||||
|
crop=True,
|
||||||
|
shuffle=False,
|
||||||
|
mode='train',
|
||||||
|
val_split=64,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.shuffle=shuffle
|
||||||
|
self.crop = crop
|
||||||
|
|
||||||
|
print('Fetching data.')
|
||||||
|
|
||||||
|
ext = ['png', 'jpg', 'jpeg', 'bmp']
|
||||||
|
self.image_files = []
|
||||||
|
[self.image_files.extend(glob.glob(f'{data_root}/img/' + '*.' + e)) for e in ext]
|
||||||
|
if mode == 'val':
|
||||||
|
self.image_files = self.image_files[:len(self.image_files)//val_split]
|
||||||
|
|
||||||
|
print('Constructing image-caption map.')
|
||||||
|
|
||||||
|
self.examples = {}
|
||||||
|
self.hashes = []
|
||||||
|
for i in self.image_files:
|
||||||
|
hash = i[len(f'{data_root}/img/'):].split('.')[0]
|
||||||
|
self.examples[hash] = {
|
||||||
|
'image': i,
|
||||||
|
'text': f'{data_root}/txt/{hash}.txt'
|
||||||
|
}
|
||||||
|
self.hashes.append(hash)
|
||||||
|
|
||||||
|
print(f'image-caption map has {len(self.examples.keys())} examples')
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||||
|
"bilinear": PIL.Image.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
def random_sample(self):
|
||||||
|
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
||||||
|
|
||||||
|
def sequential_sample(self, i):
|
||||||
|
if i >= self.__len__() - 1:
|
||||||
|
return self.__getitem__(0)
|
||||||
|
return self.__getitem__(i + 1)
|
||||||
|
|
||||||
|
def skip_sample(self, i):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_caption(self, i):
|
||||||
|
example = self.examples[self.hashes[i]]
|
||||||
|
caption = open(example['text'], 'r').read()
|
||||||
|
caption = caption.replace(' ', ' ').replace('\n', ' ').lstrip().rstrip()
|
||||||
|
return caption
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.image_files)
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example_ret = {}
|
||||||
|
try:
|
||||||
|
image_file = self.examples[self.hashes[i]]['image']
|
||||||
|
image = Image.open(image_file)
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
except (OSError, ValueError) as e:
|
||||||
|
print(f'Error with {image_file} -- skipping {i}')
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
caption = self.get_caption(i)
|
||||||
|
if caption == None:
|
||||||
|
raise ValueError
|
||||||
|
except (OSError, ValueError) as e:
|
||||||
|
print(f'Error with caption of {image_file} -- skipping {i}')
|
||||||
|
return self.skip_sample(i)
|
||||||
|
|
||||||
|
example_ret['caption'] = caption
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
if self.crop:
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example_ret["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example_ret
|
||||||
|
|
||||||
|
def get_image(self, i):
|
||||||
|
try:
|
||||||
|
image_file = self.examples[self.hashes[i]]['image']
|
||||||
|
image = Image.open(image_file)
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
except Exception as e:
|
||||||
|
print(f'Error with {image_file} -- skipping {i}')
|
||||||
|
return self.skip_sample(i)
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
if self.crop:
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
return image
|
||||||
|
|
||||||
|
"""
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dataset = LocalBase('./danbooru-aesthetic', size=512, crop=False, mode='val')
|
||||||
|
print(dataset.__len__())
|
||||||
|
example = dataset.__getitem__(0)
|
||||||
|
print(dataset.hashes[0])
|
||||||
|
print(example['caption'])
|
||||||
|
image = example['image']
|
||||||
|
image = ((image + 1) * 127.5).astype(np.uint8)
|
||||||
|
image = Image.fromarray(image)
|
||||||
|
image.save('example.png')
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""
|
||||||
|
from tqdm import tqdm
|
||||||
|
if __name__ == "__main__":
|
||||||
|
dataset = LocalBase('../glide-finetune/touhou-portrait-aesthetic', size=512)
|
||||||
|
for i in tqdm(range(dataset.__len__())):
|
||||||
|
image = dataset.get_image(i)
|
||||||
|
if image == None:
|
||||||
|
continue
|
||||||
|
image.save(f'./danbooru-aesthetic/img/{dataset.hashes[i]}.png')
|
||||||
|
with open(f'./danbooru-aesthetic/txt/{dataset.hashes[i]}.txt', 'w') as f:
|
||||||
|
f.write(dataset.get_caption(i))
|
||||||
|
|
||||||
|
"""
|
|
@ -0,0 +1,92 @@
|
||||||
|
import os
|
||||||
|
import numpy as np
|
||||||
|
import PIL
|
||||||
|
from PIL import Image
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from torchvision import transforms
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNBase(Dataset):
|
||||||
|
def __init__(self,
|
||||||
|
txt_file,
|
||||||
|
data_root,
|
||||||
|
size=None,
|
||||||
|
interpolation="bicubic",
|
||||||
|
flip_p=0.5
|
||||||
|
):
|
||||||
|
self.data_paths = txt_file
|
||||||
|
self.data_root = data_root
|
||||||
|
with open(self.data_paths, "r") as f:
|
||||||
|
self.image_paths = f.read().splitlines()
|
||||||
|
self._length = len(self.image_paths)
|
||||||
|
self.labels = {
|
||||||
|
"relative_file_path_": [l for l in self.image_paths],
|
||||||
|
"file_path_": [os.path.join(self.data_root, l)
|
||||||
|
for l in self.image_paths],
|
||||||
|
}
|
||||||
|
|
||||||
|
self.size = size
|
||||||
|
self.interpolation = {"linear": PIL.Image.LINEAR,
|
||||||
|
"bilinear": PIL.Image.BILINEAR,
|
||||||
|
"bicubic": PIL.Image.BICUBIC,
|
||||||
|
"lanczos": PIL.Image.LANCZOS,
|
||||||
|
}[interpolation]
|
||||||
|
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self._length
|
||||||
|
|
||||||
|
def __getitem__(self, i):
|
||||||
|
example = dict((k, self.labels[k][i]) for k in self.labels)
|
||||||
|
image = Image.open(example["file_path_"])
|
||||||
|
if not image.mode == "RGB":
|
||||||
|
image = image.convert("RGB")
|
||||||
|
|
||||||
|
# default to score-sde preprocessing
|
||||||
|
img = np.array(image).astype(np.uint8)
|
||||||
|
crop = min(img.shape[0], img.shape[1])
|
||||||
|
h, w, = img.shape[0], img.shape[1]
|
||||||
|
img = img[(h - crop) // 2:(h + crop) // 2,
|
||||||
|
(w - crop) // 2:(w + crop) // 2]
|
||||||
|
|
||||||
|
image = Image.fromarray(img)
|
||||||
|
if self.size is not None:
|
||||||
|
image = image.resize((self.size, self.size), resample=self.interpolation)
|
||||||
|
|
||||||
|
image = self.flip(image)
|
||||||
|
image = np.array(image).astype(np.uint8)
|
||||||
|
example["image"] = (image / 127.5 - 1.0).astype(np.float32)
|
||||||
|
return example
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNChurchesTrain(LSUNBase):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNChurchesValidation(LSUNBase):
|
||||||
|
def __init__(self, flip_p=0., **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
|
||||||
|
flip_p=flip_p, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNBedroomsTrain(LSUNBase):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNBedroomsValidation(LSUNBase):
|
||||||
|
def __init__(self, flip_p=0.0, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
|
||||||
|
flip_p=flip_p, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNCatsTrain(LSUNBase):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LSUNCatsValidation(LSUNBase):
|
||||||
|
def __init__(self, flip_p=0., **kwargs):
|
||||||
|
super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
|
||||||
|
flip_p=flip_p, **kwargs)
|
|
@ -0,0 +1,98 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaWarmUpCosineScheduler:
|
||||||
|
"""
|
||||||
|
note: use with a base_lr of 1.0
|
||||||
|
"""
|
||||||
|
def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
|
||||||
|
self.lr_warm_up_steps = warm_up_steps
|
||||||
|
self.lr_start = lr_start
|
||||||
|
self.lr_min = lr_min
|
||||||
|
self.lr_max = lr_max
|
||||||
|
self.lr_max_decay_steps = max_decay_steps
|
||||||
|
self.last_lr = 0.
|
||||||
|
self.verbosity_interval = verbosity_interval
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||||
|
if n < self.lr_warm_up_steps:
|
||||||
|
lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
|
||||||
|
self.last_lr = lr
|
||||||
|
return lr
|
||||||
|
else:
|
||||||
|
t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
|
||||||
|
t = min(t, 1.0)
|
||||||
|
lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
|
||||||
|
1 + np.cos(t * np.pi))
|
||||||
|
self.last_lr = lr
|
||||||
|
return lr
|
||||||
|
|
||||||
|
def __call__(self, n, **kwargs):
|
||||||
|
return self.schedule(n,**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaWarmUpCosineScheduler2:
|
||||||
|
"""
|
||||||
|
supports repeated iterations, configurable via lists
|
||||||
|
note: use with a base_lr of 1.0.
|
||||||
|
"""
|
||||||
|
def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
|
||||||
|
assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
|
||||||
|
self.lr_warm_up_steps = warm_up_steps
|
||||||
|
self.f_start = f_start
|
||||||
|
self.f_min = f_min
|
||||||
|
self.f_max = f_max
|
||||||
|
self.cycle_lengths = cycle_lengths
|
||||||
|
self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
|
||||||
|
self.last_f = 0.
|
||||||
|
self.verbosity_interval = verbosity_interval
|
||||||
|
|
||||||
|
def find_in_interval(self, n):
|
||||||
|
interval = 0
|
||||||
|
for cl in self.cum_cycles[1:]:
|
||||||
|
if n <= cl:
|
||||||
|
return interval
|
||||||
|
interval += 1
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
cycle = self.find_in_interval(n)
|
||||||
|
n = n - self.cum_cycles[cycle]
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||||
|
f"current cycle {cycle}")
|
||||||
|
if n < self.lr_warm_up_steps[cycle]:
|
||||||
|
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
|
||||||
|
t = min(t, 1.0)
|
||||||
|
f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
|
||||||
|
1 + np.cos(t * np.pi))
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
|
||||||
|
def __call__(self, n, **kwargs):
|
||||||
|
return self.schedule(n, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||||
|
|
||||||
|
def schedule(self, n, **kwargs):
|
||||||
|
cycle = self.find_in_interval(n)
|
||||||
|
n = n - self.cum_cycles[cycle]
|
||||||
|
if self.verbosity_interval > 0:
|
||||||
|
if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||||
|
f"current cycle {cycle}")
|
||||||
|
|
||||||
|
if n < self.lr_warm_up_steps[cycle]:
|
||||||
|
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
else:
|
||||||
|
f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
|
||||||
|
self.last_f = f
|
||||||
|
return f
|
||||||
|
|
|
@ -0,0 +1,443 @@
|
||||||
|
import torch
|
||||||
|
import pytorch_lightning as pl
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||||
|
|
||||||
|
from ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||||
|
from ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||||
|
|
||||||
|
from ldm.util import instantiate_from_config
|
||||||
|
|
||||||
|
|
||||||
|
class VQModel(pl.LightningModule):
|
||||||
|
def __init__(self,
|
||||||
|
ddconfig,
|
||||||
|
lossconfig,
|
||||||
|
n_embed,
|
||||||
|
embed_dim,
|
||||||
|
ckpt_path=None,
|
||||||
|
ignore_keys=[],
|
||||||
|
image_key="image",
|
||||||
|
colorize_nlabels=None,
|
||||||
|
monitor=None,
|
||||||
|
batch_resize_range=None,
|
||||||
|
scheduler_config=None,
|
||||||
|
lr_g_factor=1.0,
|
||||||
|
remap=None,
|
||||||
|
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||||
|
use_ema=False
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
self.n_embed = n_embed
|
||||||
|
self.image_key = image_key
|
||||||
|
self.encoder = Encoder(**ddconfig)
|
||||||
|
self.decoder = Decoder(**ddconfig)
|
||||||
|
self.loss = instantiate_from_config(lossconfig)
|
||||||
|
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||||
|
remap=remap,
|
||||||
|
sane_index_shape=sane_index_shape)
|
||||||
|
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||||
|
if colorize_nlabels is not None:
|
||||||
|
assert type(colorize_nlabels)==int
|
||||||
|
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||||
|
if monitor is not None:
|
||||||
|
self.monitor = monitor
|
||||||
|
self.batch_resize_range = batch_resize_range
|
||||||
|
if self.batch_resize_range is not None:
|
||||||
|
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||||
|
|
||||||
|
self.use_ema = use_ema
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema = LitEma(self)
|
||||||
|
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||||
|
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
self.scheduler_config = scheduler_config
|
||||||
|
self.lr_g_factor = lr_g_factor
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def ema_scope(self, context=None):
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema.store(self.parameters())
|
||||||
|
self.model_ema.copy_to(self)
|
||||||
|
if context is not None:
|
||||||
|
print(f"{context}: Switched to EMA weights")
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema.restore(self.parameters())
|
||||||
|
if context is not None:
|
||||||
|
print(f"{context}: Restored training weights")
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||||
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del sd[k]
|
||||||
|
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||||
|
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||||
|
if len(missing) > 0:
|
||||||
|
print(f"Missing Keys: {missing}")
|
||||||
|
print(f"Unexpected Keys: {unexpected}")
|
||||||
|
|
||||||
|
def on_train_batch_end(self, *args, **kwargs):
|
||||||
|
if self.use_ema:
|
||||||
|
self.model_ema(self)
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
return quant, emb_loss, info
|
||||||
|
|
||||||
|
def encode_to_prequant(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def decode(self, quant):
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def decode_code(self, code_b):
|
||||||
|
quant_b = self.quantize.embed_code(code_b)
|
||||||
|
dec = self.decode(quant_b)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, return_pred_indices=False):
|
||||||
|
quant, diff, (_,_,ind) = self.encode(input)
|
||||||
|
dec = self.decode(quant)
|
||||||
|
if return_pred_indices:
|
||||||
|
return dec, diff, ind
|
||||||
|
return dec, diff
|
||||||
|
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||||
|
if self.batch_resize_range is not None:
|
||||||
|
lower_size = self.batch_resize_range[0]
|
||||||
|
upper_size = self.batch_resize_range[1]
|
||||||
|
if self.global_step <= 4:
|
||||||
|
# do the first few batches with max size to avoid later oom
|
||||||
|
new_resize = upper_size
|
||||||
|
else:
|
||||||
|
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||||
|
if new_resize != x.shape[2]:
|
||||||
|
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||||
|
x = x.detach()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
# https://github.com/pytorch/pytorch/issues/37142
|
||||||
|
# try not to fool the heuristics
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
# autoencode
|
||||||
|
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train",
|
||||||
|
predicted_indices=ind)
|
||||||
|
|
||||||
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||||
|
return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
log_dict = self._validation_step(batch, batch_idx)
|
||||||
|
with self.ema_scope():
|
||||||
|
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||||
|
return log_dict
|
||||||
|
|
||||||
|
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||||
|
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split="val"+suffix,
|
||||||
|
predicted_indices=ind
|
||||||
|
)
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||||
|
self.global_step,
|
||||||
|
last_layer=self.get_last_layer(),
|
||||||
|
split="val"+suffix,
|
||||||
|
predicted_indices=ind
|
||||||
|
)
|
||||||
|
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||||
|
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||||
|
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||||
|
self.log(f"val{suffix}/aeloss", aeloss,
|
||||||
|
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||||
|
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||||
|
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr_d = self.learning_rate
|
||||||
|
lr_g = self.lr_g_factor*self.learning_rate
|
||||||
|
print("lr_d", lr_d)
|
||||||
|
print("lr_g", lr_g)
|
||||||
|
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||||
|
list(self.decoder.parameters())+
|
||||||
|
list(self.quantize.parameters())+
|
||||||
|
list(self.quant_conv.parameters())+
|
||||||
|
list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr_g, betas=(0.5, 0.9))
|
||||||
|
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||||
|
lr=lr_d, betas=(0.5, 0.9))
|
||||||
|
|
||||||
|
if self.scheduler_config is not None:
|
||||||
|
scheduler = instantiate_from_config(self.scheduler_config)
|
||||||
|
|
||||||
|
print("Setting up LambdaLR scheduler...")
|
||||||
|
scheduler = [
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||||
|
'interval': 'step',
|
||||||
|
'frequency': 1
|
||||||
|
},
|
||||||
|
]
|
||||||
|
return [opt_ae, opt_disc], scheduler
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if only_inputs:
|
||||||
|
log["inputs"] = x
|
||||||
|
return log
|
||||||
|
xrec, _ = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log["inputs"] = x
|
||||||
|
log["reconstructions"] = xrec
|
||||||
|
if plot_ema:
|
||||||
|
with self.ema_scope():
|
||||||
|
xrec_ema, _ = self(x)
|
||||||
|
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||||
|
log["reconstructions_ema"] = xrec_ema
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == "segmentation"
|
||||||
|
if not hasattr(self, "colorize"):
|
||||||
|
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class VQModelInterface(VQModel):
|
||||||
|
def __init__(self, embed_dim, *args, **kwargs):
|
||||||
|
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
h = self.quant_conv(h)
|
||||||
|
return h
|
||||||
|
|
||||||
|
def decode(self, h, force_not_quantize=False):
|
||||||
|
# also go through quantization layer
|
||||||
|
if not force_not_quantize:
|
||||||
|
quant, emb_loss, info = self.quantize(h)
|
||||||
|
else:
|
||||||
|
quant = h
|
||||||
|
quant = self.post_quant_conv(quant)
|
||||||
|
dec = self.decoder(quant)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
|
||||||
|
class AutoencoderKL(pl.LightningModule):
|
||||||
|
def __init__(self,
|
||||||
|
ddconfig,
|
||||||
|
lossconfig,
|
||||||
|
embed_dim,
|
||||||
|
ckpt_path=None,
|
||||||
|
ignore_keys=[],
|
||||||
|
image_key="image",
|
||||||
|
colorize_nlabels=None,
|
||||||
|
monitor=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.image_key = image_key
|
||||||
|
self.encoder = Encoder(**ddconfig)
|
||||||
|
self.decoder = Decoder(**ddconfig)
|
||||||
|
self.loss = instantiate_from_config(lossconfig)
|
||||||
|
assert ddconfig["double_z"]
|
||||||
|
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||||
|
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||||
|
self.embed_dim = embed_dim
|
||||||
|
if colorize_nlabels is not None:
|
||||||
|
assert type(colorize_nlabels)==int
|
||||||
|
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||||
|
if monitor is not None:
|
||||||
|
self.monitor = monitor
|
||||||
|
if ckpt_path is not None:
|
||||||
|
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||||
|
|
||||||
|
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||||
|
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||||
|
keys = list(sd.keys())
|
||||||
|
for k in keys:
|
||||||
|
for ik in ignore_keys:
|
||||||
|
if k.startswith(ik):
|
||||||
|
print("Deleting key {} from state_dict.".format(k))
|
||||||
|
del sd[k]
|
||||||
|
self.load_state_dict(sd, strict=False)
|
||||||
|
print(f"Restored from {path}")
|
||||||
|
|
||||||
|
def encode(self, x):
|
||||||
|
h = self.encoder(x)
|
||||||
|
moments = self.quant_conv(h)
|
||||||
|
posterior = DiagonalGaussianDistribution(moments)
|
||||||
|
return posterior
|
||||||
|
|
||||||
|
def decode(self, z):
|
||||||
|
z = self.post_quant_conv(z)
|
||||||
|
dec = self.decoder(z)
|
||||||
|
return dec
|
||||||
|
|
||||||
|
def forward(self, input, sample_posterior=True):
|
||||||
|
posterior = self.encode(input)
|
||||||
|
if sample_posterior:
|
||||||
|
z = posterior.sample()
|
||||||
|
else:
|
||||||
|
z = posterior.mode()
|
||||||
|
dec = self.decode(z)
|
||||||
|
return dec, posterior
|
||||||
|
|
||||||
|
def get_input(self, batch, k):
|
||||||
|
x = batch[k]
|
||||||
|
if len(x.shape) == 3:
|
||||||
|
x = x[..., None]
|
||||||
|
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||||
|
return x
|
||||||
|
|
||||||
|
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
|
||||||
|
if optimizer_idx == 0:
|
||||||
|
# train encoder+decoder+logvar
|
||||||
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||||
|
return aeloss
|
||||||
|
|
||||||
|
if optimizer_idx == 1:
|
||||||
|
# train the discriminator
|
||||||
|
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="train")
|
||||||
|
|
||||||
|
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||||
|
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||||
|
return discloss
|
||||||
|
|
||||||
|
def validation_step(self, batch, batch_idx):
|
||||||
|
inputs = self.get_input(batch, self.image_key)
|
||||||
|
reconstructions, posterior = self(inputs)
|
||||||
|
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="val")
|
||||||
|
|
||||||
|
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||||
|
last_layer=self.get_last_layer(), split="val")
|
||||||
|
|
||||||
|
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||||
|
self.log_dict(log_dict_ae)
|
||||||
|
self.log_dict(log_dict_disc)
|
||||||
|
return self.log_dict
|
||||||
|
|
||||||
|
def configure_optimizers(self):
|
||||||
|
lr = self.learning_rate
|
||||||
|
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||||
|
list(self.decoder.parameters())+
|
||||||
|
list(self.quant_conv.parameters())+
|
||||||
|
list(self.post_quant_conv.parameters()),
|
||||||
|
lr=lr, betas=(0.5, 0.9))
|
||||||
|
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||||
|
lr=lr, betas=(0.5, 0.9))
|
||||||
|
return [opt_ae, opt_disc], []
|
||||||
|
|
||||||
|
def get_last_layer(self):
|
||||||
|
return self.decoder.conv_out.weight
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||||
|
log = dict()
|
||||||
|
x = self.get_input(batch, self.image_key)
|
||||||
|
x = x.to(self.device)
|
||||||
|
if not only_inputs:
|
||||||
|
xrec, posterior = self(x)
|
||||||
|
if x.shape[1] > 3:
|
||||||
|
# colorize with random projection
|
||||||
|
assert xrec.shape[1] > 3
|
||||||
|
x = self.to_rgb(x)
|
||||||
|
xrec = self.to_rgb(xrec)
|
||||||
|
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||||
|
log["reconstructions"] = xrec
|
||||||
|
log["inputs"] = x
|
||||||
|
return log
|
||||||
|
|
||||||
|
def to_rgb(self, x):
|
||||||
|
assert self.image_key == "segmentation"
|
||||||
|
if not hasattr(self, "colorize"):
|
||||||
|
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||||
|
x = F.conv2d(x, weight=self.colorize)
|
||||||
|
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityFirstStage(torch.nn.Module):
|
||||||
|
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||||
|
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def encode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def decode(self, x, *args, **kwargs):
|
||||||
|
return x
|
||||||
|
|
||||||
|
def quantize(self, x, *args, **kwargs):
|
||||||
|
if self.vq_interface:
|
||||||
|
return x, None, [None, None, None]
|
||||||
|
return x
|
||||||
|
|
||||||
|
def forward(self, x, *args, **kwargs):
|
||||||
|
return x
|