revup virtual diff target
2fc9b8d379d98efa55fe6e64c65f3344a22cad8db1111240ff450a9b6fada7b41d90bd
|
@ -0,0 +1,44 @@
|
|||
name: CI
|
||||
run-name: ${{ github.actor }} is running Riffusion CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- 'main'
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened, ready_for_review]
|
||||
|
||||
jobs:
|
||||
riffusion-ci:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.9", "3.10"]
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
|
||||
- name: Install system packages
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y ffmpeg libsndfile1
|
||||
|
||||
- name: Install pip packages from requirements.txt
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r requirements.txt
|
||||
|
||||
- name: Install pip packages from dev_requirements.txt
|
||||
run: |
|
||||
pip install -r dev_requirements.txt
|
||||
|
||||
- name: Test with unittest
|
||||
run: |
|
||||
RIFFUSION_TEST_DEVICE=cpu python -m unittest test/*_test.py
|
|
@ -0,0 +1,141 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# VSCode
|
||||
.vscode
|
||||
|
||||
# Cog
|
||||
.cog/
|
||||
|
||||
# Random stuff I don't care about
|
||||
.graveyard/
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# OSX cruft
|
||||
.DS_Store
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
|
@ -0,0 +1,6 @@
|
|||
@article{Forsgren_Martiros_2022,
|
||||
author = {Forsgren, Seth* and Martiros, Hayk*},
|
||||
title = {{Riffusion - Stable diffusion for real-time music generation}},
|
||||
url = {https://riffusion.com/about},
|
||||
year = {2022}
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
Copyright 2022 Hayk Martiros and Seth Forsgren
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all copies or substantial
|
||||
portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
|
||||
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
@ -0,0 +1,220 @@
|
|||
# :guitar: Riffusion
|
||||
|
||||
<a href="https://github.com/riffusion/riffusion/actions/workflows/ci.yml?query=branch%3Amain"><img alt="CI status" src="https://github.com/riffusion/riffusion/actions/workflows/ci.yml/badge.svg" /></a>
|
||||
<img alt="Python 3.9 | 3.10" src="https://img.shields.io/badge/Python-3.9%20%7C%203.10-blue" />
|
||||
<a href="https://github.com/riffusion/riffusion/tree/main/LICENSE"><img alt="MIT License" src="https://img.shields.io/badge/License-MIT-yellowgreen" /></a>
|
||||
|
||||
Riffusion is a library for real-time music and audio generation with stable diffusion.
|
||||
|
||||
Read about it at https://www.riffusion.com/about and try it at https://www.riffusion.com/.
|
||||
|
||||
This is the core repository for riffusion image and audio processing code.
|
||||
|
||||
* Diffusion pipeline that performs prompt interpolation combined with image conditioning
|
||||
* Conversions between spectrogram images and audio clips
|
||||
* Command-line interface for common tasks
|
||||
* Interactive app using streamlit
|
||||
* Flask server to provide model inference via API
|
||||
* Various third-party integrations
|
||||
|
||||
Related repositories:
|
||||
* Web app: https://github.com/riffusion/riffusion-app
|
||||
* Model checkpoint: https://huggingface.co/riffusion/riffusion-model-v1
|
||||
|
||||
## Citation
|
||||
|
||||
If you build on this work, please cite it as follows:
|
||||
|
||||
```
|
||||
@article{Forsgren_Martiros_2022,
|
||||
author = {Forsgren, Seth* and Martiros, Hayk*},
|
||||
title = {{Riffusion - Stable diffusion for real-time music generation}},
|
||||
url = {https://riffusion.com/about},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
## Install
|
||||
|
||||
Tested in CI with Python 3.9 and 3.10.
|
||||
|
||||
It's highly recommended to set up a virtual Python environment with `conda` or `virtualenv`:
|
||||
```
|
||||
conda create --name riffusion python=3.9
|
||||
conda activate riffusion
|
||||
```
|
||||
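If you prefer `virtualenv`, a roughly equivalent setup is (a sketch; the environment name here is arbitrary):
```
virtualenv --python=python3.9 riffusion-env
source riffusion-env/bin/activate
```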
|
||||
Install Python dependencies:
|
||||
```
|
||||
python -m pip install -r requirements.txt
|
||||
```
|
||||
|
||||
In order to use audio formats other than WAV, [ffmpeg](https://ffmpeg.org/download.html) is required.
|
||||
```
|
||||
sudo apt-get install ffmpeg # linux
|
||||
brew install ffmpeg # mac
|
||||
conda install -c conda-forge ffmpeg # conda
|
||||
```
|
||||
|
||||
If torchaudio has no backend, you may need to install `libsndfile`. See [this issue](https://github.com/riffusion/riffusion/issues/12).
|
||||
|
||||
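For example, it can typically be installed with your system package manager (package names may vary by platform):
```
sudo apt-get install libsndfile1  # linux
brew install libsndfile           # mac
```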
If you have an issue, try upgrading [diffusers](https://github.com/huggingface/diffusers). Tested with 0.9 - 0.11.
|
||||
|
||||
Guides:
|
||||
* [Simple Install Guide for Windows](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
|
||||
|
||||
## Backends
|
||||
|
||||
### CPU
|
||||
`cpu` is supported but is quite slow.
|
||||
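For example, the CLI commands described below accept a device argument, so a CPU run might look like this (assuming `argh` exposes the `device` keyword argument as a `--device` flag, as it does for the other arguments):
```
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav --device cpu
```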
|
||||
### CUDA
|
||||
`cuda` is the recommended and most performant backend.
|
||||
|
||||
To use with CUDA, make sure you have torch and torchaudio installed with CUDA support. See the
|
||||
[install guide](https://pytorch.org/get-started/locally/) or
|
||||
[stable wheels](https://download.pytorch.org/whl/torch_stable.html).
|
||||
|
||||
To generate audio in real-time, you need a GPU that can run stable diffusion with approximately 50
|
||||
steps in under five seconds, such as a 3090 or A10G.
|
||||
|
||||
Test availability with:
|
||||
|
||||
```python3
|
||||
import torch
|
||||
torch.cuda.is_available()
|
||||
```
|
||||
|
||||
### MPS
|
||||
The `mps` backend on Apple Silicon is supported for inference but some operations fall back to CPU,
|
||||
particularly for audio processing. You may need to set
|
||||
`PYTORCH_ENABLE_MPS_FALLBACK=1`.
|
||||
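For example, export it in your shell before running riffusion:
```
export PYTORCH_ENABLE_MPS_FALLBACK=1
```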
|
||||
In addition, this backend is not deterministic.
|
||||
|
||||
Test availability with:
|
||||
|
||||
```python3
|
||||
import torch
|
||||
torch.backends.mps.is_available()
|
||||
```
|
||||
|
||||
## Command-line interface
|
||||
|
||||
Riffusion comes with a command line interface for performing common tasks.
|
||||
|
||||
See available commands:
|
||||
```
|
||||
python -m riffusion.cli -h
|
||||
```
|
||||
|
||||
Get help for a specific command:
|
||||
```
|
||||
python -m riffusion.cli image-to-audio -h
|
||||
```
|
||||
|
||||
Execute:
|
||||
```
|
||||
python -m riffusion.cli image-to-audio --image spectrogram_image.png --audio clip.wav
|
||||
```
|
||||
|
||||
## Riffusion Playground
|
||||
|
||||
Riffusion contains a [streamlit](https://streamlit.io/) app for interactive use and exploration.
|
||||
|
||||
Run with:
|
||||
```
|
||||
python -m streamlit run riffusion/streamlit/playground.py --browser.serverAddress 127.0.0.1 --browser.serverPort 8501
|
||||
```
|
||||
|
||||
And access at http://127.0.0.1:8501/
|
||||
|
||||
<img alt="Riffusion Playground" style="width: 600px" src="https://i.imgur.com/OOMKBbT.png" />
|
||||
|
||||
## Run the model server
|
||||
|
||||
Riffusion can be run as a Flask server that provides inference via API. This server enables the [web app](https://github.com/riffusion/riffusion-app) to run locally.
|
||||
|
||||
Run with:
|
||||
|
||||
```
|
||||
python -m riffusion.server --host 127.0.0.1 --port 3013
|
||||
```
|
||||
|
||||
You can specify `--checkpoint` with your own checkpoint directory or a Hugging Face model ID in diffusers format.
|
||||
|
||||
Use the `--device` argument to specify the torch device to use.
|
||||
|
||||
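For example (the checkpoint ID and device below are illustrative):
```
python -m riffusion.server --host 127.0.0.1 --port 3013 --checkpoint riffusion/riffusion-model-v1 --device cuda
```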
The model endpoint is now available at `http://127.0.0.1:3013/run_inference` via POST request.
|
||||
|
||||
Example input (see [InferenceInput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L28) for the API):
|
||||
```
|
||||
{
|
||||
"alpha": 0.75,
|
||||
"num_inference_steps": 50,
|
||||
"seed_image_id": "og_beat",
|
||||
|
||||
"start": {
|
||||
"prompt": "church bells on sunday",
|
||||
"seed": 42,
|
||||
"denoising": 0.75,
|
||||
"guidance": 7.0
|
||||
},
|
||||
|
||||
"end": {
|
||||
"prompt": "jazz with piano",
|
||||
"seed": 123,
|
||||
"denoising": 0.75,
|
||||
"guidance": 7.0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-inference/blob/main/riffusion/datatypes.py#L54) for the API):
|
||||
```
|
||||
{
|
||||
"image": "< base64 encoded JPEG image >",
|
||||
"audio": "< base64 encoded MP3 clip >"
|
||||
}
|
||||
```
|
||||
|
||||
## Tests
|
||||
Tests live in the `test/` directory and are implemented with `unittest`.
|
||||
|
||||
To run all tests:
|
||||
```
|
||||
python -m unittest test/*_test.py
|
||||
```
|
||||
|
||||
To run a single test:
|
||||
```
|
||||
python -m unittest test.audio_to_image_test
|
||||
```
|
||||
|
||||
To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
|
||||
```
|
||||
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
|
||||
```
|
||||
|
||||
To run a single test case within a test:
|
||||
```
|
||||
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
|
||||
```
|
||||
|
||||
To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with
|
||||
`cpu`, `cuda`, and `mps` backends.
|
||||
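For example, to run the full suite on CUDA:
```
RIFFUSION_TEST_DEVICE=cuda python -m unittest test/*_test.py
```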
|
||||
## Development Guide
|
||||
Install additional packages for dev with `python -m pip install -r dev_requirements.txt`.
|
||||
|
||||
* Linter: `ruff`
|
||||
* Formatter: `black`
|
||||
* Type checker: `mypy`
|
||||
|
||||
These are configured in `pyproject.toml`.
|
||||
|
||||
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
|
||||
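You can run all three locally before sending a PR:
```
black .
ruff .
mypy .
```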
|
||||
CI is run through GitHub Actions from `.github/workflows/ci.yml`.
|
||||
|
||||
Contributions are welcome through pull requests.
|
|
@ -0,0 +1,38 @@
|
|||
# Configuration for Cog ⚙️
|
||||
# Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
|
||||
|
||||
build:
|
||||
# set to true if your model requires a GPU
|
||||
gpu: true
|
||||
|
||||
# a list of ubuntu apt packages to install
|
||||
system_packages:
|
||||
- "ffmpeg"
|
||||
- "libsndfile1"
|
||||
|
||||
# python version in the form '3.8' or '3.8.12'
|
||||
python_version: "3.9"
|
||||
|
||||
# a list of packages in the format <package-name>==<version>
|
||||
python_packages:
|
||||
- "accelerate==0.15.0"
|
||||
- "argh==0.26.2"
|
||||
- "dacite==1.6.0"
|
||||
- "diffusers==0.10.2"
|
||||
- "flask_cors==3.0.10"
|
||||
- "flask==1.1.2"
|
||||
- "numpy==1.19.4"
|
||||
- "pillow==9.1.0"
|
||||
- "pydub==0.25.1"
|
||||
- "scipy==1.6.3"
|
||||
- "torch==1.13.0"
|
||||
- "torchaudio==0.13.0"
|
||||
- "transformers==4.25.1"
|
||||
|
||||
# commands run after the environment is setup
|
||||
# run:
|
||||
# - "echo env is ready!"
|
||||
# - "echo another command if needed"
|
||||
|
||||
# predict.py defines how predictions are run on your model
|
||||
predict: "integrations/cog_riffusion.py:RiffusionPredictor"
|
|
@ -0,0 +1,7 @@
|
|||
black
|
||||
ipdb
|
||||
mypy
|
||||
ruff
|
||||
types-Flask-Cors
|
||||
types-Pillow
|
||||
types-requests
|
|
@ -0,0 +1,23 @@
|
|||
# Integrations
|
||||
|
||||
This package contains integrations of Riffusion into third party apps and deployments.
|
||||
|
||||
## Baseten
|
||||
|
||||
[Baseten](https://baseten.com) is a platform for building and deploying machine learning models.
|
||||
|
||||
## Replicate
|
||||
|
||||
To run riffusion as a Cog model, first [install Cog](https://github.com/replicate/cog) and
|
||||
download the model weights:
|
||||
|
||||
cog run python -m integrations.cog_riffusion --download_weights
|
||||
|
||||
Then you can run predictions:
|
||||
|
||||
cog predict -i prompt_a="funky synth solo"
|
||||
|
||||
You can also view the model on replicate [here](https://replicate.com/hmartiro/riffusion). Owners
|
||||
can push an updated version of the model like so:
|
||||
|
||||
cog push r8.im/hmartiro/riffusion
|
|
@ -0,0 +1,83 @@
|
|||
"""
|
||||
This file can be used to build a Truss for deployment with Baseten.
|
||||
If used, it should be renamed to model.py and placed alongside the other
|
||||
files from /riffusion in the standard /model directory of the Truss.
|
||||
|
||||
For more on the Truss file format, see https://truss.baseten.co/
|
||||
"""
|
||||
|
||||
import typing as T
|
||||
|
||||
import dacite
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from riffusion.datatypes import InferenceInput
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.server import compute_request
|
||||
|
||||
|
||||
class Model:
|
||||
"""
|
||||
Baseten Truss model class for riffusion.
|
||||
|
||||
See: https://truss.baseten.co/reference/structure#model.py
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self._data_dir = kwargs["data_dir"]
|
||||
self._config = kwargs["config"]
|
||||
self._pipeline = None
|
||||
self._vae = None
|
||||
|
||||
self.checkpoint_name = "riffusion/riffusion-model-v1"
|
||||
|
||||
# Download entire seed image folder from huggingface hub
|
||||
self._seed_images_dir = snapshot_download(self.checkpoint_name, allow_patterns="*.png")
|
||||
|
||||
def load(self):
|
||||
"""
|
||||
Load the model. Guaranteed to be called before `predict`.
|
||||
"""
|
||||
self._pipeline = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=self.checkpoint_name,
|
||||
use_traced_unet=True,
|
||||
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
|
||||
)
|
||||
|
||||
def preprocess(self, request: T.Dict) -> T.Dict:
|
||||
"""
|
||||
Incorporate pre-processing required by the model if desired here.
|
||||
|
||||
These might be feature transformations that are tightly coupled to the model.
|
||||
"""
|
||||
return request
|
||||
|
||||
def predict(self, request: T.Dict) -> T.Dict[str, T.List]:
|
||||
"""
|
||||
This is the main function that is called.
|
||||
"""
|
||||
assert self._pipeline is not None, "Model pipeline not loaded"
|
||||
|
||||
try:
|
||||
inputs = dacite.from_dict(InferenceInput, request)
|
||||
except dacite.exceptions.WrongTypeError as exception:
|
||||
return str(exception), 400
|
||||
except dacite.exceptions.MissingValueError as exception:
|
||||
return str(exception), 400
|
||||
|
||||
# NOTE: Autocast disabled to speed up inference, previous inference time was 10s on T4
|
||||
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=False):
|
||||
response = compute_request(
|
||||
inputs=inputs,
|
||||
pipeline=self._pipeline,
|
||||
seed_images_dir=self._seed_images_dir,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def postprocess(self, request: T.Dict) -> T.Dict:
|
||||
"""
|
||||
Incorporate post-processing required by the model if desired here.
|
||||
"""
|
||||
return request
|
|
@ -0,0 +1,158 @@
|
|||
"""
|
||||
Prediction interface for Cog ⚙️
|
||||
https://github.com/replicate/cog/blob/main/docs/python.md
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from cog import BaseModel, BasePredictor, Input, Path
|
||||
|
||||
from riffusion.datatypes import InferenceInput, PromptInput
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
MODEL_ID = "riffusion/riffusion-model-v1"
|
||||
MODEL_CACHE = "riffusion-cache"
|
||||
|
||||
# Where built-in seed images are stored
|
||||
SEED_IMAGES_DIR = Path("./seed_images")
|
||||
SEED_IMAGES = [val.split(".")[0] for val in os.listdir(SEED_IMAGES_DIR) if "png" in val]
|
||||
SEED_IMAGES.sort()
|
||||
|
||||
|
||||
class Output(BaseModel):
|
||||
"""
|
||||
Output class for riffusion predictions
|
||||
"""
|
||||
|
||||
audio: Path
|
||||
spectrogram: Path
|
||||
error: T.Optional[str] = None
|
||||
|
||||
|
||||
class RiffusionPredictor(BasePredictor):
|
||||
"""
|
||||
Implementation of the Cog predictor object, so that riffusion predictions can be run with Cog.
|
||||
|
||||
See README & https://github.com/replicate/cog for details
|
||||
"""
|
||||
|
||||
def setup(self, local_files_only=True):
|
||||
"""
|
||||
Loads the model onto GPU from local cache.
|
||||
"""
|
||||
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
self.model = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=MODEL_ID,
|
||||
use_traced_unet=True,
|
||||
device=self.device,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=MODEL_CACHE,
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
prompt_a: str = Input(description="The prompt for your audio", default="funky synth solo"),
|
||||
denoising: float = Input(
|
||||
description="How much to transform input spectrogram",
|
||||
default=0.75,
|
||||
ge=0,
|
||||
le=1,
|
||||
),
|
||||
prompt_b: str = Input(
|
||||
description="The second prompt to interpolate with the first,"
|
||||
"leave blank if no interpolation",
|
||||
default=None,
|
||||
),
|
||||
alpha: float = Input(
|
||||
description="Interpolation alpha if using two prompts."
|
||||
"A value of 0 uses prompt_a fully, a value of 1 uses prompt_b fully",
|
||||
default=0.5,
|
||||
ge=0,
|
||||
le=1,
|
||||
),
|
||||
num_inference_steps: int = Input(
|
||||
description="Number of steps to run the diffusion model", default=50, ge=1
|
||||
),
|
||||
seed_image_id: str = Input(
|
||||
description="Seed spectrogram to use", default="vibes", choices=SEED_IMAGES
|
||||
),
|
||||
) -> Output:
|
||||
"""
|
||||
Runs riffusion inference.
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(SEED_IMAGES_DIR, f"{seed_image_id}.png")
|
||||
if not init_image_path.is_file():
|
||||
return Output(error=f"Invalid seed image: {seed_image_id}")
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Random seeds, bounded by the max int32 value
|
||||
seed_a = np.random.randint(0, 2147483647)
|
||||
seed_b = np.random.randint(0, 2147483647)
|
||||
|
||||
start = PromptInput(prompt=prompt_a, seed=seed_a, denoising=denoising)
|
||||
if not prompt_b: # no transition
|
||||
prompt_b = prompt_a
|
||||
alpha = 0
|
||||
end = PromptInput(prompt=prompt_b, seed=seed_b, denoising=denoising)
|
||||
riffusion_input = InferenceInput(
|
||||
start=start,
|
||||
end=end,
|
||||
alpha=alpha,
|
||||
num_inference_steps=num_inference_steps,
|
||||
seed_image_id=seed_image_id,
|
||||
)
|
||||
|
||||
# Execute the model to get the spectrogram image
|
||||
image = self.model.riffuse(riffusion_input, init_image=init_image, mask_image=None)
|
||||
|
||||
# Reconstruct audio from the image
|
||||
params = SpectrogramParams()
|
||||
converter = SpectrogramImageConverter(params=params, device=self.device)
|
||||
segment = converter.audio_from_spectrogram_image(image)
|
||||
|
||||
if not os.path.exists("out/"):
|
||||
os.mkdir("out")
|
||||
|
||||
out_img_path = "out/spectrogram.jpg"
|
||||
image.save(out_img_path, exif=image.getexif())
|
||||
|
||||
out_wav_path = "out/gen_sound.wav"
|
||||
segment.export(out_wav_path, format="wav")
|
||||
|
||||
return Output(audio=Path(out_wav_path), spectrogram=Path(out_img_path))
|
||||
|
||||
|
||||
# TODO(hayk): Can we get rid of the below functions and incorporate into
|
||||
# RiffusionPipeline.load_checkpoint?
|
||||
|
||||
|
||||
def download_weights():
|
||||
"""
|
||||
Clears local cache & downloads riffusion weights
|
||||
"""
|
||||
if os.path.exists(MODEL_CACHE):
|
||||
shutil.rmtree(MODEL_CACHE)
|
||||
os.makedirs(MODEL_CACHE)
|
||||
|
||||
pred = RiffusionPredictor()
|
||||
pred.setup(local_files_only=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--download_weights", action="store_true", help="Download and cache weights"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
if args.download_weights:
|
||||
download_weights()
|
|
@ -0,0 +1,99 @@
|
|||
[tool.black]
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
# Which rules to run
|
||||
select = [
|
||||
# Pyflakes
|
||||
"F",
|
||||
# Pycodestyle
|
||||
"E",
|
||||
"W",
|
||||
# isort
|
||||
"I001"
|
||||
]
|
||||
|
||||
ignore = []
|
||||
|
||||
# Exclude a variety of commonly ignored directories.
|
||||
exclude = [
|
||||
".bzr",
|
||||
".direnv",
|
||||
".eggs",
|
||||
".git",
|
||||
".hg",
|
||||
".mypy_cache",
|
||||
".nox",
|
||||
".pants.d",
|
||||
".ruff_cache",
|
||||
".svn",
|
||||
".tox",
|
||||
".venv",
|
||||
"__pypackages__",
|
||||
"_build",
|
||||
"buck-out",
|
||||
"build",
|
||||
"dist",
|
||||
"node_modules",
|
||||
"venv",
|
||||
]
|
||||
per-file-ignores = {}
|
||||
|
||||
# Allow unused variables when underscore-prefixed.
|
||||
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
|
||||
|
||||
# Assume Python 3.10.
|
||||
target-version = "py310"
|
||||
|
||||
[tool.ruff.mccabe]
|
||||
# Unlike Flake8, default to a complexity level of 10.
|
||||
max-complexity = 10
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "argh.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "cog.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "diffusers.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lora_diffusion.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "numpy.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "plotly.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "pydub.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "scipy.fft.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "scipy.io.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "torchaudio.*"
|
||||
ignore_missing_imports = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = "transformers.*"
|
||||
ignore_missing_imports = true
|
|
@ -0,0 +1,21 @@
|
|||
accelerate
|
||||
argh
|
||||
dacite
|
||||
demucs
|
||||
diffusers>=0.9.0
|
||||
flask
|
||||
flask_cors
|
||||
numpy
|
||||
pillow>=9.1.0
|
||||
plotly
|
||||
pydub
|
||||
pysoundfile
|
||||
scipy
|
||||
soundfile
|
||||
sox
|
||||
streamlit>=1.10.0
|
||||
torch
|
||||
torchaudio
|
||||
torchvision
|
||||
transformers
|
||||
git+https://github.com/cloneofsimo/lora.git
|
|
@ -0,0 +1,187 @@
|
|||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
import torch
|
||||
import torchaudio
|
||||
from torchaudio.transforms import Fade
|
||||
|
||||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def split_audio(
|
||||
segment: pydub.AudioSegment,
|
||||
model_name: str = "htdemucs_6s",
|
||||
extension: str = "wav",
|
||||
jobs: int = 4,
|
||||
device: str = "cuda",
|
||||
) -> T.Dict[str, pydub.AudioSegment]:
|
||||
"""
|
||||
Split audio into stems using demucs.
|
||||
"""
|
||||
tmp_dir = Path(tempfile.mkdtemp(prefix="split_audio_"))
|
||||
|
||||
# Save the audio to a temporary file
|
||||
audio_path = tmp_dir / "audio.mp3"
|
||||
segment.export(audio_path, format="mp3")
|
||||
|
||||
# Assemble command
|
||||
command = [
|
||||
"demucs",
|
||||
str(audio_path),
|
||||
"--name",
|
||||
model_name,
|
||||
"--out",
|
||||
str(tmp_dir),
|
||||
"--jobs",
|
||||
str(jobs),
|
||||
"--device",
|
||||
device if device != "mps" else "cpu",
|
||||
]
|
||||
print(" ".join(command))
|
||||
|
||||
if extension == "mp3":
|
||||
command.append("--mp3")
|
||||
|
||||
# Run demucs
|
||||
subprocess.run(
|
||||
command,
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Load the stems
|
||||
stems = {}
|
||||
for stem_path in tmp_dir.glob(f"{model_name}/audio/*.{extension}"):
|
||||
stem = pydub.AudioSegment.from_file(stem_path)
|
||||
stems[stem_path.stem] = stem
|
||||
|
||||
# Delete tmp dir
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
return stems
|
||||
|
||||
|
||||
class AudioSplitter:
|
||||
"""
|
||||
Split audio into instrument stems like {drums, bass, vocals, etc.}
|
||||
|
||||
NOTE(hayk): This is deprecated as it has inferior performance to the newer hybrid transformer
|
||||
model in the demucs repo. See the function above. Probably just delete this.
|
||||
|
||||
See:
|
||||
https://pytorch.org/audio/main/tutorials/hybrid_demucs_tutorial.html
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
segment_length_s: float = 10.0,
|
||||
overlap_s: float = 0.1,
|
||||
device: str = "cuda",
|
||||
):
|
||||
self.segment_length_s = segment_length_s
|
||||
self.overlap_s = overlap_s
|
||||
self.device = device
|
||||
|
||||
self.model = self.load_model().to(device)
|
||||
|
||||
@staticmethod
|
||||
def load_model(model_path: str = "models/hdemucs_high_trained.pt") -> torchaudio.models.HDemucs:
|
||||
"""
|
||||
Load the trained HDEMUCS pytorch model.
|
||||
"""
|
||||
# NOTE(hayk): The sources are baked into the pretrained model and can't be changed
|
||||
model = torchaudio.models.hdemucs_high(sources=["drums", "bass", "other", "vocals"])
|
||||
|
||||
path = torchaudio.utils.download_asset(model_path)
|
||||
state_dict = torch.load(path)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
return model
|
||||
|
||||
def split(self, audio: pydub.AudioSegment) -> T.Dict[str, pydub.AudioSegment]:
|
||||
"""
|
||||
Split the given audio segment into instrument stems.
|
||||
"""
|
||||
if audio.channels == 1:
|
||||
audio_stereo = audio.set_channels(2)
|
||||
elif audio.channels == 2:
|
||||
audio_stereo = audio
|
||||
else:
|
||||
raise ValueError(f"Audio must be stereo, but got {audio.channels} channels")
|
||||
|
||||
# Get as (samples, channels) float numpy array
|
||||
waveform_np = np.array(audio_stereo.get_array_of_samples())
|
||||
waveform_np = waveform_np.reshape(-1, audio_stereo.channels)
|
||||
waveform_np_float = waveform_np.astype(np.float32)
|
||||
|
||||
# To torch and channels-first
|
||||
waveform = torch.from_numpy(waveform_np_float).to(self.device)
|
||||
waveform = waveform.transpose(1, 0)
|
||||
|
||||
# Normalize
|
||||
ref = waveform.mean(0)
|
||||
waveform = (waveform - ref.mean()) / ref.std()
|
||||
|
||||
# Split
|
||||
sources = self.separate_sources(
|
||||
waveform[None],
|
||||
sample_rate=audio.frame_rate,
|
||||
)[0]
|
||||
|
||||
# De-normalize
|
||||
sources = sources * ref.std() + ref.mean()
|
||||
|
||||
# To numpy
|
||||
sources_np = sources.cpu().numpy().astype(waveform_np.dtype)
|
||||
|
||||
# Convert to pydub
|
||||
stem_segments = [
|
||||
audio_util.audio_from_waveform(waveform, audio.frame_rate) for waveform in sources_np
|
||||
]
|
||||
|
||||
# Convert back to mono if necessary
|
||||
if audio.channels == 1:
|
||||
stem_segments = [stem.set_channels(1) for stem in stem_segments]
|
||||
|
||||
return dict(zip(self.model.sources, stem_segments))
|
||||
|
||||
def separate_sources(
|
||||
self,
|
||||
waveform: torch.Tensor,
|
||||
sample_rate: int = 44100,
|
||||
):
|
||||
"""
|
||||
Apply model to a given waveform in chunks. Use fade and overlap to smooth the edges.
|
||||
"""
|
||||
batch, channels, length = waveform.shape
|
||||
|
||||
chunk_len = int(sample_rate * self.segment_length_s * (1 + self.overlap_s))
|
||||
start = 0
|
||||
end = chunk_len
|
||||
overlap_frames = self.overlap_s * sample_rate
|
||||
fade = Fade(fade_in_len=0, fade_out_len=int(overlap_frames), fade_shape="linear")
|
||||
|
||||
final = torch.zeros(batch, len(self.model.sources), channels, length, device=self.device)
|
||||
|
||||
# TODO(hayk): Improve this code, which came from the torchaudio docs
|
||||
while start < length - overlap_frames:
|
||||
chunk = waveform[:, :, start:end]
|
||||
with torch.no_grad():
|
||||
out = self.model.forward(chunk)
|
||||
out = fade(out)
|
||||
final[:, :, :, start:end] += out
|
||||
if start == 0:
|
||||
fade.fade_in_len = int(overlap_frames)
|
||||
start += int(chunk_len - overlap_frames)
|
||||
else:
|
||||
start += chunk_len
|
||||
end += chunk_len
|
||||
if end >= length:
|
||||
fade.fade_out_len = 0
|
||||
|
||||
return final
|
|
@ -0,0 +1,138 @@
|
|||
"""
|
||||
Command line tools for riffusion.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import argh
|
||||
import numpy as np
|
||||
import pydub
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
|
||||
@argh.arg("--step-size-ms", help="Duration of one pixel in the X axis of the spectrogram image")
|
||||
@argh.arg("--num-frequencies", help="Number of Y axes in the spectrogram image")
|
||||
def audio_to_image(
|
||||
*,
|
||||
audio: str,
|
||||
image: str,
|
||||
step_size_ms: int = 10,
|
||||
num_frequencies: int = 512,
|
||||
min_frequency: int = 0,
|
||||
max_frequency: int = 10000,
|
||||
window_duration_ms: int = 100,
|
||||
padded_duration_ms: int = 400,
|
||||
power_for_image: float = 0.25,
|
||||
stereo: bool = False,
|
||||
device: str = "cuda",
|
||||
):
|
||||
"""
|
||||
Compute a spectrogram image from a waveform.
|
||||
"""
|
||||
segment = pydub.AudioSegment.from_file(audio)
|
||||
|
||||
params = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=stereo,
|
||||
window_duration_ms=window_duration_ms,
|
||||
padded_duration_ms=padded_duration_ms,
|
||||
step_size_ms=step_size_ms,
|
||||
min_frequency=min_frequency,
|
||||
max_frequency=max_frequency,
|
||||
num_frequencies=num_frequencies,
|
||||
power_for_image=power_for_image,
|
||||
)
|
||||
|
||||
converter = SpectrogramImageConverter(params=params, device=device)
|
||||
|
||||
pil_image = converter.spectrogram_image_from_audio(segment)
|
||||
|
||||
pil_image.save(image, exif=pil_image.getexif(), format="PNG")
|
||||
print(f"Wrote {image}")
|
||||
|
||||
|
||||
def print_exif(*, image: str) -> None:
|
||||
"""
|
||||
Print the params of a spectrogram image as saved in the exif data.
|
||||
"""
|
||||
pil_image = Image.open(image)
|
||||
exif_data = image_util.exif_from_image(pil_image)
|
||||
|
||||
for name, value in exif_data.items():
|
||||
print(f"{name:<20} = {value:>15}")
|
||||
|
||||
|
||||
def image_to_audio(*, image: str, audio: str, device: str = "cuda"):
|
||||
"""
|
||||
Reconstruct an audio clip from a spectrogram image.
|
||||
"""
|
||||
pil_image = Image.open(image)
|
||||
|
||||
# Get parameters from image exif
|
||||
img_exif = pil_image.getexif()
|
||||
assert img_exif is not None
|
||||
|
||||
try:
|
||||
params = SpectrogramParams.from_exif(exif=img_exif)
|
||||
except (KeyError, AttributeError):
|
||||
print("WARNING: Could not find spectrogram parameters in exif data. Using defaults.")
|
||||
params = SpectrogramParams()
|
||||
|
||||
converter = SpectrogramImageConverter(params=params, device=device)
|
||||
segment = converter.audio_from_spectrogram_image(pil_image)
|
||||
|
||||
extension = Path(audio).suffix[1:]
|
||||
segment.export(audio, format=extension)
|
||||
|
||||
print(f"Wrote {audio} ({segment.duration_seconds:.2f} seconds)")
|
||||
|
||||
|
||||
def sample_clips(
|
||||
*,
|
||||
audio: str,
|
||||
output_dir: str,
|
||||
num_clips: int = 1,
|
||||
duration_ms: int = 5000,
|
||||
mono: bool = False,
|
||||
extension: str = "wav",
|
||||
seed: int = -1,
|
||||
):
|
||||
"""
|
||||
Slice an audio file into clips of the given duration.
|
||||
"""
|
||||
if seed >= 0:
|
||||
np.random.seed(seed)
|
||||
|
||||
segment = pydub.AudioSegment.from_file(audio)
|
||||
|
||||
if mono:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
output_dir_path = Path(output_dir)
|
||||
if not output_dir_path.exists():
|
||||
output_dir_path.mkdir(parents=True)
|
||||
|
||||
segment_duration_ms = int(segment.duration_seconds * 1000)
|
||||
for i in range(num_clips):
|
||||
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
|
||||
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
|
||||
|
||||
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms.{extension}"
|
||||
clip_path = output_dir_path / clip_name
|
||||
clip.export(clip_path, format=extension)
|
||||
print(f"Wrote {clip_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
argh.dispatch_commands(
|
||||
[
|
||||
audio_to_image,
|
||||
image_to_audio,
|
||||
sample_clips,
|
||||
print_exif,
|
||||
]
|
||||
)
|
|
@ -0,0 +1,73 @@
|
|||
"""
|
||||
Data model for the riffusion API.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PromptInput:
|
||||
"""
|
||||
Parameters for one end of interpolation.
|
||||
"""
|
||||
|
||||
# Text prompt fed into a CLIP model
|
||||
prompt: str
|
||||
|
||||
# Random seed for denoising
|
||||
seed: int
|
||||
|
||||
# Negative prompt to avoid (optional)
|
||||
negative_prompt: T.Optional[str] = None
|
||||
|
||||
# Denoising strength
|
||||
denoising: float = 0.75
|
||||
|
||||
# Classifier-free guidance strength
|
||||
guidance: float = 7.0
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InferenceInput:
|
||||
"""
|
||||
Parameters for a single run of the riffusion model, interpolating between
|
||||
a start and end set of PromptInputs. This is the API required for a request
|
||||
to the model server.
|
||||
"""
|
||||
|
||||
# Start point of interpolation
|
||||
start: PromptInput
|
||||
|
||||
# End point of interpolation
|
||||
end: PromptInput
|
||||
|
||||
# Interpolation alpha [0, 1]. A value of 0 uses start fully, a value of 1
|
||||
# uses end fully.
|
||||
alpha: float
|
||||
|
||||
# Number of inner loops of the diffusion model
|
||||
num_inference_steps: int = 50
|
||||
|
||||
# Which seed image to use
|
||||
seed_image_id: str = "og_beat"
|
||||
|
||||
# ID of mask image to use
|
||||
mask_image_id: T.Optional[str] = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class InferenceOutput:
|
||||
"""
|
||||
Response from the model inference server.
|
||||
"""
|
||||
|
||||
# base64 encoded spectrogram image as a JPEG
|
||||
image: str
|
||||
|
||||
# base64 encoded audio clip as an MP3
|
||||
audio: str
|
||||
|
||||
# The duration of the audio clip, in seconds
|
||||
duration_s: float
|
|
@ -0,0 +1,3 @@
|
|||
# external
|
||||
|
||||
This package contains scripts and tools from external sources.
|
|
@ -0,0 +1,24 @@
|
|||
export MODEL_NAME="riffusion/riffusion-model-v1"
|
||||
export INSTANCE_DIR="/tmp/sample_clips_tdlcqdfi/images"
|
||||
export OUTPUT_DIR="/home/ubuntu/lora_dreambooth_waterfalls_2k"
|
||||
|
||||
accelerate launch \
|
||||
--num_machines 1 \
|
||||
--num_processes 8 \
|
||||
--dynamo_backend=no \
|
||||
--mixed_precision="fp16" \
|
||||
riffusion/external/lora/train_lora_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--instance_prompt="style of sks" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=2000
|
||||
|
||||
# TODO try mixed_precision=fp16
|
||||
# TODO try num_processes = 8
|
|
@ -0,0 +1,49 @@
|
|||
from lora_diffusion.cli_lora_pti import train
|
||||
from lora_diffusion.dataset import STYLE_TEMPLATE
|
||||
|
||||
MODEL_NAME = "riffusion/riffusion-model-v1"
|
||||
INSTANCE_DIR = "/tmp/sample_clips_xzv8p57g/images"
|
||||
OUTPUT_DIR = "./lora_output_acoustic"
|
||||
|
||||
if __name__ == "__main__":
|
||||
entries = [
|
||||
"music in the style of {}",
|
||||
"sound in the style of {}",
|
||||
"vibe in the style of {}",
|
||||
"audio in the style of {}",
|
||||
"groove in the style of {}",
|
||||
]
|
||||
for i in range(len(STYLE_TEMPLATE)):
|
||||
STYLE_TEMPLATE[i] = entries[i % len(entries)]
|
||||
print(STYLE_TEMPLATE)
|
||||
|
||||
train(
|
||||
pretrained_model_name_or_path=MODEL_NAME,
|
||||
instance_data_dir=INSTANCE_DIR,
|
||||
output_dir=OUTPUT_DIR,
|
||||
train_text_encoder=True,
|
||||
resolution=512,
|
||||
train_batch_size=1,
|
||||
gradient_accumulation_steps=4,
|
||||
scale_lr=True,
|
||||
learning_rate_unet=1e-4,
|
||||
learning_rate_text=1e-5,
|
||||
learning_rate_ti=5e-4,
|
||||
color_jitter=False,
|
||||
lr_scheduler="linear",
|
||||
lr_warmup_steps=0,
|
||||
placeholder_tokens="<s1>|<s2>",
|
||||
use_template="style",
|
||||
save_steps=100,
|
||||
max_train_steps_ti=1000,
|
||||
max_train_steps_tuning=1000,
|
||||
perform_inversion=True,
|
||||
clip_ti_decay=True,
|
||||
weight_decay_ti=0.000,
|
||||
weight_decay_lora=0.001,
|
||||
continue_inversion=True,
|
||||
continue_inversion_lr=1e-4,
|
||||
device="cuda:0",
|
||||
lora_rank=1,
|
||||
use_face_segmentation_condition=False,
|
||||
)
|
|
@ -0,0 +1,37 @@
|
|||
export MODEL_NAME="riffusion/riffusion-model-v1"
|
||||
export INSTANCE_DIR="/tmp/sample_clips_xzv8p57g/images"
|
||||
export OUTPUT_DIR="./lora_output_acoustic"
|
||||
|
||||
# Notes from the original script:
#   --train_batch_size started as 1; --lora_rank could be 1 or 4.
#   Optional flags left disabled: --color_jitter, --use_template="style",
#   --use_face_segmentation_condition, initializer tokens, class prompt.

lora_pti \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--train_text_encoder \
--resolution=512 \
--train_batch_size=4 \
--gradient_accumulation_steps=4 \
--scale_lr \
--learning_rate_unet=1e-4 \
--learning_rate_text=1e-5 \
--learning_rate_ti=5e-4 \
--lr_scheduler="linear" \
--lr_warmup_steps=0 \
--placeholder_tokens="<s>" \
--save_steps=100 \
--max_train_steps_ti=1000 \
--max_train_steps_tuning=1000 \
--perform_inversion=True \
--clip_ti_decay \
--weight_decay_ti=0.000 \
--weight_decay_lora=0.001 \
--continue_inversion \
--continue_inversion_lr=1e-4 \
--device="cuda:0" \
--lora_rank=4
|
|
@ -0,0 +1,958 @@
|
|||
# Bootstrapped from:
|
||||
# https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py
|
||||
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import inspect
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.optimization import get_scheduler
|
||||
from lora_diffusion import (
|
||||
extract_lora_ups_down,
|
||||
inject_trainable_lora,
|
||||
safetensors_available,
|
||||
save_lora_weight,
|
||||
save_safeloras,
|
||||
)
|
||||
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
|
||||
from PIL import Image
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
"""
|
||||
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
|
||||
It pre-processes the images and tokenizes the prompts.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
instance_data_root,
|
||||
instance_prompt,
|
||||
tokenizer,
|
||||
class_data_root=None,
|
||||
class_prompt=None,
|
||||
size=512,
|
||||
center_crop=False,
|
||||
color_jitter=False,
|
||||
h_flip=False,
|
||||
resize=False,
|
||||
):
|
||||
self.size = size
|
||||
self.center_crop = center_crop
|
||||
self.tokenizer = tokenizer
|
||||
self.resize = resize
|
||||
|
||||
self.instance_data_root = Path(instance_data_root)
|
||||
if not self.instance_data_root.exists():
|
||||
raise ValueError("Instance images root doesn't exists.")
|
||||
|
||||
self.instance_images_path = list(Path(instance_data_root).iterdir())
|
||||
self.num_instance_images = len(self.instance_images_path)
|
||||
self.instance_prompt = instance_prompt
|
||||
self._length = self.num_instance_images
|
||||
|
||||
if class_data_root is not None:
|
||||
self.class_data_root = Path(class_data_root)
|
||||
self.class_data_root.mkdir(parents=True, exist_ok=True)
|
||||
self.class_images_path = list(self.class_data_root.iterdir())
|
||||
self.num_class_images = len(self.class_images_path)
|
||||
self._length = max(self.num_class_images, self.num_instance_images)
|
||||
self.class_prompt = class_prompt
|
||||
else:
|
||||
self.class_data_root = None
|
||||
|
||||
img_transforms = []
|
||||
|
||||
if resize:
|
||||
img_transforms.append(
|
||||
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
)
|
||||
if center_crop:
|
||||
img_transforms.append(transforms.CenterCrop(size))
|
||||
if color_jitter:
|
||||
img_transforms.append(transforms.ColorJitter(0.2, 0.1))
|
||||
if h_flip:
|
||||
img_transforms.append(transforms.RandomHorizontalFlip())
|
||||
|
||||
self.image_transforms = transforms.Compose(
|
||||
[*img_transforms, transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return self._length
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
example["instance_prompt_ids"] = self.tokenizer(
|
||||
self.instance_prompt,
|
||||
padding="do_not_pad",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids
|
||||
|
||||
if self.class_data_root:
|
||||
class_image = Image.open(self.class_images_path[index % self.num_class_images])
|
||||
if not class_image.mode == "RGB":
|
||||
class_image = class_image.convert("RGB")
|
||||
example["class_images"] = self.image_transforms(class_image)
|
||||
example["class_prompt_ids"] = self.tokenizer(
|
||||
self.class_prompt,
|
||||
padding="do_not_pad",
|
||||
truncation=True,
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
).input_ids
|
||||
|
||||
return example
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
|
||||
|
||||
def __init__(self, prompt, num_samples):
|
||||
self.prompt = prompt
|
||||
self.num_samples = num_samples
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
example["prompt"] = self.prompt
|
||||
example["index"] = index
|
||||
return example
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args(input_args=None):
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_vae_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to pretrained vae or vae identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Pretrained tokenizer name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="A folder containing the training data of instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="A folder containing the training data of class images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The prompt with identifier specifying the instance",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt to specify images in the same class as provided instance images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--with_prior_preservation",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help="Flag to add prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prior_loss_weight",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The weight of prior preservation loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_class_images",
|
||||
type=int,
|
||||
default=100,
|
||||
help=(
|
||||
"Minimal class images for prior preservation loss. If not have enough images, additional images will be"
|
||||
" sampled with class_prompt."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="text-inversion-model",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_format",
|
||||
type=str,
|
||||
choices=["pt", "safe", "both"],
|
||||
default="both",
|
||||
help="The output format of the model predicitions and checkpoints.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
action="store_true",
|
||||
help="Whether to center crop images before resizing to resolution",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--color_jitter",
|
||||
action="store_true",
|
||||
help="Whether to apply color jitter to images",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
help="Whether to train the text encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per device) for the training dataloader.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample_batch_size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Batch size (per device) for sampling images.",
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Save checkpoint every X updates steps.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Rank of LoRA approximation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=None,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate_text",
|
||||
type=float,
|
||||
default=5e-6,
|
||||
help="Initial learning rate for text encoder (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--scale_lr",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help="Number of steps for the warmup in the lr scheduler.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam",
|
||||
action="store_true",
|
||||
help="Whether or not to use 8-bit Adam from bitsandbytes.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_beta1",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="The beta1 parameter for the Adam optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_beta2",
|
||||
type=float,
|
||||
default=0.999,
|
||||
help="The beta2 parameter for the Adam optimizer.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adam_epsilon",
|
||||
type=float,
|
||||
default=1e-08,
|
||||
help="Epsilon value for the Adam optimizer",
|
||||
)
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--push_to_hub",
|
||||
action="store_true",
|
||||
help="Whether or not to push the model to the Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hub_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The token to use to push to the Model Hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--local_rank",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="For distributed training: local_rank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_unet",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("File path for unet lora to resume training."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_text_encoder",
|
||||
type=str,
|
||||
default=None,
|
||||
help=("File path for text encoder lora to resume training."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resize",
|
||||
type=bool,
|
||||
default=True,
|
||||
required=False,
|
||||
help="Should images be resized to --resolution before training?",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_xformers", action="store_true", help="Whether or not to use xformers"
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
if args.with_prior_preservation:
|
||||
if args.class_data_dir is None:
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
else:
|
||||
if args.class_data_dir is not None:
|
||||
logger.warning("You need not use --class_data_dir without --with_prior_preservation.")
|
||||
if args.class_prompt is not None:
|
||||
logger.warning("You need not use --class_prompt without --with_prior_preservation.")
|
||||
|
||||
if not safetensors_available:
|
||||
if args.output_format == "both":
|
||||
print(
|
||||
"Safetensors is not available - changing output format to just output PyTorch files"
|
||||
)
|
||||
args.output_format = "pt"
|
||||
elif args.output_format == "safe":
|
||||
raise ValueError(
|
||||
"Safetensors is not available - either install it, or change output_format."
|
||||
)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main(args):
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with="tensorboard",
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if (
|
||||
args.train_text_encoder
|
||||
and args.gradient_accumulation_steps > 1
|
||||
and accelerator.num_processes > 1
|
||||
):
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_images_dir = Path(args.class_data_dir)
|
||||
if not class_images_dir.exists():
|
||||
class_images_dir.mkdir(parents=True)
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
safety_checker=None,
|
||||
revision=args.revision,
|
||||
)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
num_new_images = args.num_class_images - cur_class_images
|
||||
logger.info(f"Number of class images to sample: {num_new_images}.")
|
||||
|
||||
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
|
||||
sample_dataloader = torch.utils.data.DataLoader(
|
||||
sample_dataset, batch_size=args.sample_batch_size
|
||||
)
|
||||
|
||||
sample_dataloader = accelerator.prepare(sample_dataloader)
|
||||
pipeline.to(accelerator.device)
|
||||
|
||||
for example in tqdm(
|
||||
sample_dataloader,
|
||||
desc="Generating class images",
|
||||
disable=not accelerator.is_local_main_process,
|
||||
):
|
||||
images = pipeline(example["prompt"]).images
|
||||
|
||||
for i, image in enumerate(images):
|
||||
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
|
||||
image_filename = (
|
||||
class_images_dir
|
||||
/ f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
|
||||
)
|
||||
image.save(image_filename)
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
# Load the tokenizer
|
||||
if args.tokenizer_name:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.tokenizer_name,
|
||||
revision=args.revision,
|
||||
)
|
||||
elif args.pretrained_model_name_or_path:
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
# Load models and create wrapper for stable diffusion
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_vae_name_or_path or args.pretrained_model_name_or_path,
|
||||
subfolder=None if args.pretrained_vae_name_or_path else "vae",
|
||||
revision=None if args.pretrained_vae_name_or_path else args.revision,
|
||||
)
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.revision,
|
||||
)
|
||||
unet.requires_grad_(False)
|
||||
unet_lora_params, _ = inject_trainable_lora(unet, r=args.lora_rank, loras=args.resume_unet)
|
||||
|
||||
for _up, _down in extract_lora_ups_down(unet):
|
||||
print("Before training: Unet First Layer lora up", _up.weight.data)
|
||||
print("Before training: Unet First Layer lora down", _down.weight.data)
|
||||
break
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
if args.train_text_encoder:
|
||||
text_encoder_lora_params, _ = inject_trainable_lora(
|
||||
text_encoder,
|
||||
target_replace_module=["CLIPAttention"],
|
||||
r=args.lora_rank,
|
||||
)
|
||||
for _up, _down in extract_lora_ups_down(
|
||||
text_encoder, target_replace_module=["CLIPAttention"]
|
||||
):
|
||||
print("Before training: text encoder First Layer lora up", _up.weight.data)
|
||||
print("Before training: text encoder First Layer lora down", _down.weight.data)
|
||||
break
|
||||
|
||||
if args.use_xformers:
|
||||
set_use_memory_efficient_attention_xformers(unet, True)
|
||||
set_use_memory_efficient_attention_xformers(vae, True)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
args.learning_rate
|
||||
* args.gradient_accumulation_steps
|
||||
* args.train_batch_size
|
||||
* accelerator.num_processes
|
||||
)
|
||||
|
||||
# Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
|
||||
)
|
||||
|
||||
optimizer_class = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
text_lr = args.learning_rate if args.learning_rate_text is None else args.learning_rate_text
|
||||
|
||||
params_to_optimize = (
|
||||
[
|
||||
{"params": itertools.chain(*unet_lora_params), "lr": args.learning_rate},
|
||||
{
|
||||
"params": itertools.chain(*text_encoder_lora_params),
|
||||
"lr": text_lr,
|
||||
},
|
||||
]
|
||||
if args.train_text_encoder
|
||||
else itertools.chain(*unet_lora_params)
|
||||
)
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_config(
|
||||
args.pretrained_model_name_or_path, subfolder="scheduler"
|
||||
)
|
||||
|
||||
train_dataset = DreamBoothDataset(
|
||||
instance_data_root=args.instance_data_dir,
|
||||
instance_prompt=args.instance_prompt,
|
||||
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
|
||||
class_prompt=args.class_prompt,
|
||||
tokenizer=tokenizer,
|
||||
size=args.resolution,
|
||||
center_crop=args.center_crop,
|
||||
color_jitter=args.color_jitter,
|
||||
resize=args.resize,
|
||||
)
|
||||
|
||||
def collate_fn(examples):
|
||||
input_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
if args.with_prior_preservation:
|
||||
input_ids += [example["class_prompt_ids"] for example in examples]
|
||||
pixel_values += [example["class_images"] for example in examples]
|
||||
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
input_ids = tokenizer.pad(
|
||||
{"input_ids": input_ids},
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
|
||||
batch = {
|
||||
"input_ids": input_ids,
|
||||
"pixel_values": pixel_values,
|
||||
}
|
||||
return batch
|
||||
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.train_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
if args.train_text_encoder:
|
||||
(
|
||||
unet,
|
||||
text_encoder,
|
||||
optimizer,
|
||||
train_dataloader,
|
||||
lr_scheduler,
|
||||
) = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader, lr_scheduler)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move text_encoder and vae to GPU.
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# since these models are only used for inference and keeping their weights in full precision is not required.
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initialize automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("dreambooth", config=vars(args))
|
||||
|
||||
# Train!
|
||||
total_batch_size = (
|
||||
args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
)
|
||||
|
||||
print("***** Running training *****")
|
||||
print(f" Num examples = {len(train_dataset)}")
|
||||
print(f" Num batches each epoch = {len(train_dataloader)}")
|
||||
print(f" Num Epochs = {args.num_train_epochs}")
|
||||
print(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
print(
|
||||
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
|
||||
)
|
||||
print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
print(f" Total optimization steps = {args.max_train_steps}")
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
last_save = 0
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
unet.train()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.train()
|
||||
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0,
|
||||
noise_scheduler.config.num_train_timesteps,
|
||||
(bsz,),
|
||||
device=latents.device,
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
# Add noise to the latents according to the noise magnitude at each timestep
|
||||
# (this is the forward diffusion process)
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
||||
)
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
|
||||
target, target_prior = torch.chunk(target, 2, dim=0)
|
||||
|
||||
# Compute instance loss
|
||||
loss = (
|
||||
F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
.mean([1, 2, 3])
|
||||
.mean()
|
||||
)
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(
|
||||
model_pred_prior.float(), target_prior.float(), reduction="mean"
|
||||
)
|
||||
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
progress_bar.update(1)
|
||||
optimizer.zero_grad()
|
||||
|
||||
global_step += 1
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
if args.save_steps and global_step - last_save >= args.save_steps:
|
||||
if accelerator.is_main_process:
|
||||
# Newer versions of accelerate allow the 'keep_fp32_wrapper' arg. Without passing
|
||||
# it, the models will be unwrapped, and when they are then used for further training,
|
||||
# we will crash. Pass this, but only to newer versions of accelerate. Fixes
|
||||
# https://github.com/huggingface/diffusers/issues/1566
|
||||
accepts_keep_fp32_wrapper = "keep_fp32_wrapper" in set(
|
||||
inspect.signature(accelerator.unwrap_model).parameters.keys()
|
||||
)
|
||||
extra_args = (
|
||||
{"keep_fp32_wrapper": True} if accepts_keep_fp32_wrapper else {}
|
||||
)
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet, **extra_args),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder, **extra_args),
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
filename_unet = f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.pt"
|
||||
filename_text_encoder = (
|
||||
f"{args.output_dir}/lora_weight_e{epoch}_s{global_step}.text_encoder.pt"
|
||||
)
|
||||
print(f"save weights {filename_unet}, {filename_text_encoder}")
|
||||
save_lora_weight(pipeline.unet, filename_unet)
|
||||
if args.train_text_encoder:
|
||||
save_lora_weight(
|
||||
pipeline.text_encoder,
|
||||
filename_text_encoder,
|
||||
target_replace_module=["CLIPAttention"],
|
||||
)
|
||||
|
||||
for _up, _down in extract_lora_ups_down(pipeline.unet):
|
||||
print(
|
||||
"First Unet Layer's Up Weight is now : ",
|
||||
_up.weight.data,
|
||||
)
|
||||
print(
|
||||
"First Unet Layer's Down Weight is now : ",
|
||||
_down.weight.data,
|
||||
)
|
||||
break
|
||||
if args.train_text_encoder:
|
||||
for _up, _down in extract_lora_ups_down(
|
||||
pipeline.text_encoder,
|
||||
target_replace_module=["CLIPAttention"],
|
||||
):
|
||||
print(
|
||||
"First Text Encoder Layer's Up Weight is now : ",
|
||||
_up.weight.data,
|
||||
)
|
||||
print(
|
||||
"First Text Encoder Layer's Down Weight is now : ",
|
||||
_down.weight.data,
|
||||
)
|
||||
break
|
||||
|
||||
last_save = global_step
|
||||
|
||||
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
print("\n\nLora TRAINING DONE!\n\n")
|
||||
|
||||
if args.output_format == "pt" or args.output_format == "both":
|
||||
save_lora_weight(pipeline.unet, args.output_dir + "/lora_weight.pt")
|
||||
if args.train_text_encoder:
|
||||
save_lora_weight(
|
||||
pipeline.text_encoder,
|
||||
args.output_dir + "/lora_weight.text_encoder.pt",
|
||||
target_replace_module=["CLIPAttention"],
|
||||
)
|
||||
|
||||
if args.output_format == "safe" or args.output_format == "both":
|
||||
loras = {}
|
||||
loras["unet"] = (pipeline.unet, {"CrossAttention", "Attention", "GEGLU"})
|
||||
if args.train_text_encoder:
|
||||
loras["text_encoder"] = (pipeline.text_encoder, {"CLIPAttention"})
|
||||
|
||||
save_safeloras(loras, args.output_dir + "/lora_weight.safetensors")
|
||||
|
||||
if args.push_to_hub:
|
||||
repo.push_to_hub(
|
||||
commit_message="End of training",
|
||||
blocking=False,
|
||||
auto_lfs_prune=True,
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
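# Post-training inspection sketch (comments only, illustrative): the output_format branch
# above writes LoRA weights as .pt and/or .safetensors files. The exact tensor layout
# depends on the save_lora_weight / save_safeloras helpers used above, so this only
# enumerates what was stored; the paths are hypothetical.
#
#   import torch
#   from safetensors.torch import load_file
#
#   pt_blob = torch.load("output/lora_weight.pt", map_location="cpu")
#   safe_blob = load_file("output/lora_weight.safetensors")
#   for key, tensor in safe_blob.items():
#       print(key, tuple(tensor.shape))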
|
|
@ -0,0 +1,372 @@
|
|||
"""
|
||||
This code is taken from the diffusers community pipeline:
|
||||
|
||||
https://github.com/huggingface/diffusers/blob/f242eba4fdc5b76dc40d3a9c01ba49b2c74b9796/examples/community/lpw_stable_diffusion.py
|
||||
|
||||
License: Apache 2.0
|
||||
"""
|
||||
# ruff: noqa
|
||||
# mypy: ignore-errors
|
||||
|
||||
import logging
|
||||
import re
|
||||
import typing as T
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
re_attention = re.compile(
|
||||
r"""
|
||||
\\\(|
|
||||
\\\)|
|
||||
\\\[|
|
||||
\\]|
|
||||
\\\\|
|
||||
\\|
|
||||
\(|
|
||||
\[|
|
||||
:([+-]?[.\d]+)\)|
|
||||
\)|
|
||||
]|
|
||||
[^\\()\[\]:]+|
|
||||
:
|
||||
""",
|
||||
re.X,
|
||||
)
|
||||
|
||||
|
||||
def parse_prompt_attention(text):
|
||||
"""
|
||||
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
||||
Accepted tokens are:
|
||||
(abc) - increases attention to abc by a multiplier of 1.1
|
||||
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
||||
[abc] - decreases attention to abc by a multiplier of 1.1
|
||||
\( - literal character '('
|
||||
\[ - literal character '['
|
||||
\) - literal character ')'
|
||||
\] - literal character ']'
|
||||
\\ - literal character '\'
|
||||
anything else - just text
|
||||
>>> parse_prompt_attention('normal text')
|
||||
[['normal text', 1.0]]
|
||||
>>> parse_prompt_attention('an (important) word')
|
||||
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
||||
>>> parse_prompt_attention('(unbalanced')
|
||||
[['unbalanced', 1.1]]
|
||||
>>> parse_prompt_attention('\(literal\]')
|
||||
[['(literal]', 1.0]]
|
||||
>>> parse_prompt_attention('(unnecessary)(parens)')
|
||||
[['unnecessaryparens', 1.1]]
|
||||
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
||||
[['a ', 1.0],
|
||||
['house', 1.5730000000000004],
|
||||
[' ', 1.1],
|
||||
['on', 1.0],
|
||||
[' a ', 1.1],
|
||||
['hill', 0.55],
|
||||
[', sun, ', 1.1],
|
||||
['sky', 1.4641000000000006],
|
||||
['.', 1.1]]
|
||||
"""
|
||||
|
||||
res = []
|
||||
round_brackets = []
|
||||
square_brackets = []
|
||||
|
||||
round_bracket_multiplier = 1.1
|
||||
square_bracket_multiplier = 1 / 1.1
|
||||
|
||||
def multiply_range(start_position, multiplier):
|
||||
for p in range(start_position, len(res)):
|
||||
res[p][1] *= multiplier
|
||||
|
||||
for m in re_attention.finditer(text):
|
||||
text = m.group(0)
|
||||
weight = m.group(1)
|
||||
|
||||
if text.startswith("\\"):
|
||||
res.append([text[1:], 1.0])
|
||||
elif text == "(":
|
||||
round_brackets.append(len(res))
|
||||
elif text == "[":
|
||||
square_brackets.append(len(res))
|
||||
elif weight is not None and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), float(weight))
|
||||
elif text == ")" and len(round_brackets) > 0:
|
||||
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
||||
elif text == "]" and len(square_brackets) > 0:
|
||||
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
||||
else:
|
||||
res.append([text, 1.0])
|
||||
|
||||
for pos in round_brackets:
|
||||
multiply_range(pos, round_bracket_multiplier)
|
||||
|
||||
for pos in square_brackets:
|
||||
multiply_range(pos, square_bracket_multiplier)
|
||||
|
||||
if len(res) == 0:
|
||||
res = [["", 1.0]]
|
||||
|
||||
# merge runs of identical weights
|
||||
i = 0
|
||||
while i + 1 < len(res):
|
||||
if res[i][1] == res[i + 1][1]:
|
||||
res[i][0] += res[i + 1][0]
|
||||
res.pop(i + 1)
|
||||
else:
|
||||
i += 1
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: T.List[str], max_length: int):
|
||||
r"""
|
||||
Tokenize a list of prompts and return their tokens along with the weight of each token.
|
||||
No padding, starting or ending token is included.
|
||||
"""
|
||||
tokens = []
|
||||
weights = []
|
||||
truncated = False
|
||||
for text in prompt:
|
||||
texts_and_weights = parse_prompt_attention(text)
|
||||
text_token = []
|
||||
text_weight = []
|
||||
for word, weight in texts_and_weights:
|
||||
# tokenize and discard the starting and the ending token
|
||||
token = pipe.tokenizer(word).input_ids[1:-1]
|
||||
text_token += token
|
||||
# copy the weight by length of token
|
||||
text_weight += [weight] * len(token)
|
||||
# stop if the text is too long (longer than truncation limit)
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
break
|
||||
# truncate
|
||||
if len(text_token) > max_length:
|
||||
truncated = True
|
||||
text_token = text_token[:max_length]
|
||||
text_weight = text_weight[:max_length]
|
||||
tokens.append(text_token)
|
||||
weights.append(text_weight)
|
||||
if truncated:
|
||||
logger.warning(
|
||||
"Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples"
|
||||
)
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def pad_tokens_and_weights(
|
||||
tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77
|
||||
):
|
||||
r"""
|
||||
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
||||
"""
|
||||
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
||||
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
||||
for i in range(len(tokens)):
|
||||
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
||||
if no_boseos_middle:
|
||||
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
||||
else:
|
||||
w = []
|
||||
if len(weights[i]) == 0:
|
||||
w = [1.0] * weights_length
|
||||
else:
|
||||
for j in range(max_embeddings_multiples):
|
||||
w.append(1.0) # weight for starting token in this chunk
|
||||
w += weights[i][
|
||||
j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))
|
||||
]
|
||||
w.append(1.0) # weight for ending token in this chunk
|
||||
w += [1.0] * (weights_length - len(w))
|
||||
weights[i] = w[:]
|
||||
|
||||
return tokens, weights
|
||||
|
||||
|
||||
def get_unweighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
text_input: torch.Tensor,
|
||||
chunk_length: int,
|
||||
no_boseos_middle: T.Optional[bool] = True,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
When the token sequence is longer than the capacity of the text encoder,
|
||||
it is split into chunks and each chunk is sent to the text encoder individually.
|
||||
"""
|
||||
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
||||
if max_embeddings_multiples > 1:
|
||||
text_embeddings = []
|
||||
for i in range(max_embeddings_multiples):
|
||||
# extract the i-th chunk
|
||||
text_input_chunk = text_input[
|
||||
:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2
|
||||
].clone()
|
||||
|
||||
# cover the head and the tail by the starting and the ending tokens
|
||||
text_input_chunk[:, 0] = text_input[0, 0]
|
||||
text_input_chunk[:, -1] = text_input[0, -1]
|
||||
text_embedding = pipe.text_encoder(text_input_chunk)[0]
|
||||
|
||||
if no_boseos_middle:
|
||||
if i == 0:
|
||||
# discard the ending token
|
||||
text_embedding = text_embedding[:, :-1]
|
||||
elif i == max_embeddings_multiples - 1:
|
||||
# discard the starting token
|
||||
text_embedding = text_embedding[:, 1:]
|
||||
else:
|
||||
# discard both starting and ending tokens
|
||||
text_embedding = text_embedding[:, 1:-1]
|
||||
|
||||
text_embeddings.append(text_embedding)
|
||||
text_embeddings = torch.concat(text_embeddings, axis=1)
|
||||
else:
|
||||
text_embeddings = pipe.text_encoder(text_input)[0]
|
||||
return text_embeddings
|
||||
|
||||
|
||||
def get_weighted_text_embeddings(
|
||||
pipe: StableDiffusionPipeline,
|
||||
prompt: T.Union[str, T.List[str]],
|
||||
uncond_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
|
||||
max_embeddings_multiples: T.Optional[int] = 3,
|
||||
no_boseos_middle: T.Optional[bool] = False,
|
||||
skip_parsing: T.Optional[bool] = False,
|
||||
skip_weighting: T.Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> T.Tuple[torch.FloatTensor, T.Optional[torch.FloatTensor]]:
|
||||
r"""
|
||||
Prompts can be assigned with local weights using brackets. For example,
|
||||
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
||||
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
||||
Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
|
||||
Args:
|
||||
pipe (`StableDiffusionPipeline`):
|
||||
Pipe to provide access to the tokenizer and the text encoder.
|
||||
prompt (`str` or `T.List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
uncond_prompt (`str` or `T.List[str]`):
|
||||
The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
|
||||
is provided, the embeddings of prompt and uncond_prompt are concatenated.
|
||||
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
||||
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
||||
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
||||
If the token sequence spans multiple text encoder chunks, whether to keep the starting and
|
||||
ending tokens in each of the middle chunks.
|
||||
skip_parsing (`bool`, *optional*, defaults to `False`):
|
||||
Skip the parsing of brackets.
|
||||
skip_weighting (`bool`, *optional*, defaults to `False`):
|
||||
Skip the weighting. When the parsing is skipped, it is forced to True.
|
||||
"""
|
||||
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
if not skip_parsing:
|
||||
prompt_tokens, prompt_weights = get_prompts_with_weights(pipe, prompt, max_length - 2)
|
||||
|
||||
if uncond_prompt is not None:
|
||||
if isinstance(uncond_prompt, str):
|
||||
uncond_prompt = [uncond_prompt]
|
||||
uncond_tokens, uncond_weights = get_prompts_with_weights(
|
||||
pipe, uncond_prompt, max_length - 2
|
||||
)
|
||||
else:
|
||||
prompt_tokens = [
|
||||
token[1:-1]
|
||||
for token in pipe.tokenizer(prompt, max_length=max_length, truncation=True).input_ids
|
||||
]
|
||||
prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
|
||||
if uncond_prompt is not None:
|
||||
if isinstance(uncond_prompt, str):
|
||||
uncond_prompt = [uncond_prompt]
|
||||
uncond_tokens = [
|
||||
token[1:-1]
|
||||
for token in pipe.tokenizer(
|
||||
uncond_prompt, max_length=max_length, truncation=True
|
||||
).input_ids
|
||||
]
|
||||
uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
|
||||
|
||||
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
||||
max_length = max([len(token) for token in prompt_tokens])
|
||||
if uncond_prompt is not None:
|
||||
max_length = max(max_length, max([len(token) for token in uncond_tokens]))
|
||||
|
||||
max_embeddings_multiples = min(
|
||||
max_embeddings_multiples,
|
||||
(max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
|
||||
)
|
||||
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
||||
max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
||||
|
||||
# pad the length of tokens and weights
|
||||
bos = pipe.tokenizer.bos_token_id
|
||||
eos = pipe.tokenizer.eos_token_id
|
||||
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
||||
prompt_tokens,
|
||||
prompt_weights,
|
||||
max_length,
|
||||
bos,
|
||||
eos,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=pipe.tokenizer.model_max_length,
|
||||
)
|
||||
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
|
||||
if uncond_prompt is not None:
|
||||
uncond_tokens, uncond_weights = pad_tokens_and_weights(
|
||||
uncond_tokens,
|
||||
uncond_weights,
|
||||
max_length,
|
||||
bos,
|
||||
eos,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
chunk_length=pipe.tokenizer.model_max_length,
|
||||
)
|
||||
uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
|
||||
|
||||
# get the embeddings
|
||||
text_embeddings = get_unweighted_text_embeddings(
|
||||
pipe,
|
||||
prompt_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
|
||||
if uncond_prompt is not None:
|
||||
uncond_embeddings = get_unweighted_text_embeddings(
|
||||
pipe,
|
||||
uncond_tokens,
|
||||
pipe.tokenizer.model_max_length,
|
||||
no_boseos_middle=no_boseos_middle,
|
||||
)
|
||||
uncond_weights = torch.tensor(
|
||||
uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device
|
||||
)
|
||||
|
||||
# assign weights to the prompts and normalize in the sense of mean
|
||||
# TODO: should we normalize by chunk or in a whole (current implementation)?
|
||||
if (not skip_parsing) and (not skip_weighting):
|
||||
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= prompt_weights.unsqueeze(-1)
|
||||
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
||||
text_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
if uncond_prompt is not None:
|
||||
previous_mean = (
|
||||
uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
)
|
||||
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
||||
current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(uncond_embeddings.dtype)
|
||||
uncond_embeddings *= (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
||||
|
||||
if uncond_prompt is not None:
|
||||
return text_embeddings, uncond_embeddings
|
||||
return text_embeddings, None
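# Usage sketch for the helpers above (illustrative, not part of the diff). The checkpoint
# name matches the one used elsewhere in this repo; loading it downloads the model.
from diffusers import StableDiffusionPipeline


def _prompt_weighting_demo() -> None:
    # Bracket parsing alone is pure Python:
    print(parse_prompt_attention("a (jazzy:1.3) saxophone solo"))
    # -> [['a ', 1.0], ['jazzy', 1.3], [' saxophone solo', 1.0]]

    pipe = StableDiffusionPipeline.from_pretrained("riffusion/riffusion-model-v1")
    text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
        pipe=pipe,
        prompt="a (jazzy:1.3) saxophone solo",
        uncond_prompt="",
        max_embeddings_multiples=3,
    )
    # This short prompt fits in one 77-token chunk, so both tensors have shape
    # (1, 77, hidden_dim); longer prompts concatenate up to 3 chunks.
    print(text_embeddings.shape, uncond_embeddings.shape)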
|
|
@ -0,0 +1,477 @@
|
|||
"""
|
||||
Riffusion inference pipeline.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
|
||||
from diffusers.utils import logging
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from riffusion.datatypes import InferenceInput
|
||||
from riffusion.external.prompt_weighting import get_weighted_text_embeddings
|
||||
from riffusion.util import torch_util
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class RiffusionPipeline(DiffusionPipeline):
|
||||
"""
|
||||
Diffusers pipeline for doing a controlled img2img interpolation for audio generation.
|
||||
|
||||
# TODO(hayk): Document more
|
||||
|
||||
Part of this code was adapted from the non-img2img interpolation pipeline at:
|
||||
|
||||
https://github.com/huggingface/diffusers/blob/main/examples/community/interpolate_stable_diffusion.py
|
||||
|
||||
Check the documentation for DiffusionPipeline for full information.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: T.Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
unet=unet,
|
||||
scheduler=scheduler,
|
||||
safety_checker=safety_checker,
|
||||
feature_extractor=feature_extractor,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_checkpoint(
|
||||
cls,
|
||||
checkpoint: str,
|
||||
use_traced_unet: bool = True,
|
||||
channels_last: bool = False,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
device: str = "cuda",
|
||||
local_files_only: bool = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
cache_dir: T.Optional[str] = None,
|
||||
) -> RiffusionPipeline:
|
||||
"""
|
||||
Load the riffusion model pipeline.
|
||||
|
||||
Args:
|
||||
checkpoint: Model checkpoint on disk in diffusers format
|
||||
use_traced_unet: Whether to use the traced unet for speedups
|
||||
device: Device to load the model on
|
||||
channels_last: Whether to use channels_last memory format
|
||||
local_files_only: Don't download, only use local files
|
||||
low_cpu_mem_usage: Attempt to use less memory on CPU
|
||||
"""
|
||||
device = torch_util.check_device(device)
|
||||
|
||||
if device == "cpu" or device.lower().startswith("mps"):
|
||||
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
|
||||
dtype = torch.float32
|
||||
|
||||
pipeline = RiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=dtype,
|
||||
# Disable the NSFW filter, which causes false positives
|
||||
# TODO(hayk): Disable the "you have passed a non-standard module" warning from this.
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
).to(device)
|
||||
|
||||
if channels_last:
|
||||
pipeline.unet.to(memory_format=torch.channels_last)
|
||||
|
||||
# Optionally load a traced unet
|
||||
if checkpoint == "riffusion/riffusion-model-v1" and use_traced_unet:
|
||||
traced_unet = cls.load_traced_unet(
|
||||
checkpoint=checkpoint,
|
||||
subfolder="unet_traced",
|
||||
filename="unet_traced.pt",
|
||||
in_channels=pipeline.unet.in_channels,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
|
||||
if traced_unet is not None:
|
||||
pipeline.unet = traced_unet
|
||||
|
||||
model = pipeline.to(device)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def load_traced_unet(
|
||||
checkpoint: str,
|
||||
subfolder: str,
|
||||
filename: str,
|
||||
in_channels: int,
|
||||
dtype: torch.dtype,
|
||||
device: str = "cuda",
|
||||
local_files_only=False,
|
||||
cache_dir: T.Optional[str] = None,
|
||||
) -> T.Optional[torch.nn.Module]:
|
||||
"""
|
||||
Load a traced unet from the huggingface hub. This can improve performance.
|
||||
"""
|
||||
if device == "cpu" or device.lower().startswith("mps"):
|
||||
print("WARNING: Traced UNet only available for CUDA, skipping")
|
||||
return None
|
||||
|
||||
# Download and load the traced unet
|
||||
unet_file = hf_hub_download(
|
||||
checkpoint,
|
||||
subfolder=subfolder,
|
||||
filename=filename,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
unet_traced = torch.jit.load(unet_file)
|
||||
|
||||
# Wrap it in a torch module
|
||||
class TracedUNet(torch.nn.Module):
|
||||
@dataclasses.dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.FloatTensor
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.device = device
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return self.UNet2DConditionOutput(sample=sample)
|
||||
|
||||
return TracedUNet()
|
||||
|
||||
@property
|
||||
def device(self) -> str:
|
||||
return str(self.vae.device)
|
||||
|
||||
@functools.lru_cache()
|
||||
def embed_text(self, text) -> torch.FloatTensor:
|
||||
"""
|
||||
Takes in text and turns it into text embeddings.
|
||||
"""
|
||||
text_input = self.tokenizer(
|
||||
text,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
with torch.no_grad():
|
||||
embed = self.text_encoder(text_input.input_ids.to(self.device))[0]
|
||||
return embed
|
||||
|
||||
@functools.lru_cache()
|
||||
def embed_text_weighted(self, text) -> torch.FloatTensor:
|
||||
"""
|
||||
Get text embedding with weights.
|
||||
"""
|
||||
return get_weighted_text_embeddings(
|
||||
pipe=self,
|
||||
prompt=text,
|
||||
uncond_prompt=None,
|
||||
max_embeddings_multiples=3,
|
||||
no_boseos_middle=False,
|
||||
skip_parsing=False,
|
||||
skip_weighting=False,
|
||||
)[0]
|
||||
|
||||
@torch.no_grad()
|
||||
def riffuse(
|
||||
self,
|
||||
inputs: InferenceInput,
|
||||
init_image: Image.Image,
|
||||
mask_image: T.Optional[Image.Image] = None,
|
||||
use_reweighting: bool = True,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Runs inference using interpolation with both img2img and text conditioning.
|
||||
|
||||
Args:
|
||||
inputs: Parameter dataclass
|
||||
init_image: Image used for conditioning
|
||||
mask_image: White pixels in the mask will be replaced by noise and therefore repainted,
|
||||
while black pixels will be preserved. It will be converted to a single
|
||||
channel (luminance) before use.
|
||||
use_reweighting: Use prompt reweighting
|
||||
"""
|
||||
alpha = inputs.alpha
|
||||
start = inputs.start
|
||||
end = inputs.end
|
||||
|
||||
guidance_scale = start.guidance * (1.0 - alpha) + end.guidance * alpha
|
||||
|
||||
# TODO(hayk): Always generate the seed on CPU?
|
||||
if self.device.lower().startswith("mps"):
|
||||
generator_start = torch.Generator(device="cpu").manual_seed(start.seed)
|
||||
generator_end = torch.Generator(device="cpu").manual_seed(end.seed)
|
||||
else:
|
||||
generator_start = torch.Generator(device=self.device).manual_seed(start.seed)
|
||||
generator_end = torch.Generator(device=self.device).manual_seed(end.seed)
|
||||
|
||||
# Text encodings
|
||||
if use_reweighting:
|
||||
embed_start = self.embed_text_weighted(start.prompt)
|
||||
embed_end = self.embed_text_weighted(end.prompt)
|
||||
else:
|
||||
embed_start = self.embed_text(start.prompt)
|
||||
embed_end = self.embed_text(end.prompt)
|
||||
|
||||
text_embedding = embed_start + alpha * (embed_end - embed_start)
|
||||
|
||||
# Image latents
|
||||
init_image_torch = preprocess_image(init_image).to(
|
||||
device=self.device, dtype=embed_start.dtype
|
||||
)
|
||||
init_latent_dist = self.vae.encode(init_image_torch).latent_dist
|
||||
# TODO(hayk): Probably this seed should just be 0 always? Make it 100% symmetric. The
|
||||
# result is so close no matter the seed that it doesn't really add variety.
|
||||
if self.device.lower().startswith("mps"):
|
||||
generator = torch.Generator(device="cpu").manual_seed(start.seed)
|
||||
else:
|
||||
generator = torch.Generator(device=self.device).manual_seed(start.seed)
|
||||
|
||||
init_latents = init_latent_dist.sample(generator=generator)
|
||||
init_latents = 0.18215 * init_latents
|
||||
|
||||
# Prepare mask latent
|
||||
mask: T.Optional[torch.Tensor] = None
|
||||
if mask_image:
|
||||
vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
mask = preprocess_mask(mask_image, scale_factor=vae_scale_factor).to(
|
||||
device=self.device, dtype=embed_start.dtype
|
||||
)
|
||||
|
||||
outputs = self.interpolate_img2img(
|
||||
text_embeddings=text_embedding,
|
||||
init_latents=init_latents,
|
||||
mask=mask,
|
||||
generator_a=generator_start,
|
||||
generator_b=generator_end,
|
||||
interpolate_alpha=alpha,
|
||||
strength_a=start.denoising,
|
||||
strength_b=end.denoising,
|
||||
num_inference_steps=inputs.num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
)
|
||||
|
||||
return outputs["images"][0]
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate_img2img(
|
||||
self,
|
||||
text_embeddings: torch.Tensor,
|
||||
init_latents: torch.Tensor,
|
||||
generator_a: torch.Generator,
|
||||
generator_b: torch.Generator,
|
||||
interpolate_alpha: float,
|
||||
mask: T.Optional[torch.Tensor] = None,
|
||||
strength_a: float = 0.8,
|
||||
strength_b: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: T.Optional[T.Union[str, T.List[str]]] = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
eta: T.Optional[float] = 0.0,
|
||||
output_type: T.Optional[str] = "pil",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Run img2img denoising on the given latents, interpolating the added noise and the denoising strength between generator_a and generator_b by interpolate_alpha.
|
||||
"""
|
||||
batch_size = text_embeddings.shape[0]
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
bs_embed, seq_len, _ = text_embeddings.shape
|
||||
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
||||
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
uncond_tokens = [""]
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError("The length of `negative_prompt` should be equal to batch_size.")
|
||||
else:
|
||||
uncond_tokens = negative_prompt
|
||||
|
||||
# max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt
|
||||
uncond_embeddings = uncond_embeddings.repeat_interleave(
|
||||
batch_size * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
|
||||
latents_dtype = text_embeddings.dtype
|
||||
|
||||
strength = (1 - interpolate_alpha) * strength_a + interpolate_alpha * strength_b
|
||||
|
||||
# get the original timestep using init_timestep
|
||||
offset = self.scheduler.config.get("steps_offset", 0)
|
||||
init_timestep = int(num_inference_steps * strength) + offset
|
||||
init_timestep = min(init_timestep, num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps[-init_timestep]
|
||||
timesteps = torch.tensor(
|
||||
[timesteps] * batch_size * num_images_per_prompt, device=self.device
|
||||
)
|
||||
|
||||
# add noise to latents using the timesteps
|
||||
noise_a = torch.randn(
|
||||
init_latents.shape, generator=generator_a, device=self.device, dtype=latents_dtype
|
||||
)
|
||||
noise_b = torch.randn(
|
||||
init_latents.shape, generator=generator_b, device=self.device, dtype=latents_dtype
|
||||
)
|
||||
noise = torch_util.slerp(interpolate_alpha, noise_a, noise_b)
|
||||
init_latents_orig = init_latents
|
||||
init_latents = self.scheduler.add_noise(init_latents, noise, timesteps)
|
||||
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same args
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
latents = init_latents.clone()
|
||||
|
||||
t_start = max(num_inference_steps - init_timestep + offset, 0)
|
||||
|
||||
# Some schedulers like PNDM have timesteps as arrays
|
||||
# It's more optimized to move all timesteps to the correct device beforehand
|
||||
timesteps = self.scheduler.timesteps[t_start:].to(self.device)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = (
|
||||
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
)
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input, t, encoder_hidden_states=text_embeddings
|
||||
).sample
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if mask is not None:
|
||||
init_latents_proper = self.scheduler.add_noise(
|
||||
init_latents_orig, noise, torch.tensor([t])
|
||||
)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
latents = (init_latents_proper * mask) + (latents * (1 - mask))
|
||||
|
||||
latents = 1.0 / 0.18215 * latents
|
||||
image = self.vae.decode(latents).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
return dict(images=image, latents=latents, nsfw_content_detected=False)
|
||||
|
||||
|
||||
def preprocess_image(image: Image.Image) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess an image for the model.
|
||||
"""
|
||||
w, h = image.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
|
||||
image = image.resize((w, h), resample=Image.LANCZOS)
|
||||
|
||||
image_np = np.array(image).astype(np.float32) / 255.0
|
||||
image_np = image_np[None].transpose(0, 3, 1, 2)
|
||||
|
||||
image_torch = torch.from_numpy(image_np)
|
||||
|
||||
return 2.0 * image_torch - 1.0
|
||||
|
||||
|
||||
def preprocess_mask(mask: Image.Image, scale_factor: int = 8) -> torch.Tensor:
|
||||
"""
|
||||
Preprocess a mask for the model.
|
||||
"""
|
||||
# Convert to grayscale
|
||||
mask = mask.convert("L")
|
||||
|
||||
# Resize to integer multiple of 32
|
||||
w, h = mask.size
|
||||
w, h = map(lambda x: x - x % 32, (w, h))
|
||||
mask = mask.resize((w // scale_factor, h // scale_factor), resample=Image.NEAREST)
|
||||
|
||||
# Convert to numpy array and rescale
|
||||
mask_np = np.array(mask).astype(np.float32) / 255.0
|
||||
|
||||
# Tile and transpose
|
||||
mask_np = np.tile(mask_np, (4, 1, 1))
|
||||
mask_np = mask_np[None].transpose(0, 1, 2, 3)  # add a batch dimension; the (0, 1, 2, 3) transpose is an identity op
|
||||
|
||||
# Invert to repaint white and keep black
|
||||
mask_np = 1 - mask_np # repaint white, keep black
|
||||
|
||||
return torch.from_numpy(mask_np)
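# Sanity-check sketch for preprocess_mask (illustrative, not part of the diff). Shapes
# assume the default scale_factor of 8 and a 512x512 input; it reuses the Image and torch
# imports at the top of this file.
def _preprocess_mask_demo() -> None:
    white_mask = Image.new("RGB", (512, 512), color=(255, 255, 255))
    mask = preprocess_mask(white_mask, scale_factor=8)
    assert mask.shape == (1, 4, 64, 64)  # 4 latent channels at 1/8 spatial resolution
    assert bool(torch.all(mask == 0.0))  # white inverts to 0, so those latents are repainted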
|
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Flask server that serves the riffusion model as an API.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import dacite
|
||||
import flask
|
||||
import PIL
|
||||
from flask_cors import CORS
|
||||
|
||||
from riffusion.datatypes import InferenceInput, InferenceOutput
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import base64_util
|
||||
|
||||
# Flask app with CORS
|
||||
app = flask.Flask(__name__)
|
||||
CORS(app)
|
||||
|
||||
# Log at the INFO level to both stdout and disk
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logging.getLogger().addHandler(logging.FileHandler("server.log"))
|
||||
|
||||
# Global variable for the model pipeline
|
||||
PIPELINE: T.Optional[RiffusionPipeline] = None
|
||||
|
||||
# Where built-in seed images are stored
|
||||
SEED_IMAGES_DIR = Path(Path(__file__).resolve().parent.parent, "seed_images")
|
||||
|
||||
|
||||
def run_app(
|
||||
*,
|
||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||
no_traced_unet: bool = False,
|
||||
device: str = "cuda",
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 3013,
|
||||
debug: bool = False,
|
||||
ssl_certificate: T.Optional[str] = None,
|
||||
ssl_key: T.Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Run a Flask API that serves the given riffusion model checkpoint.
|
||||
"""
|
||||
# Initialize the model
|
||||
global PIPELINE
|
||||
PIPELINE = RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=checkpoint,
|
||||
use_traced_unet=not no_traced_unet,
|
||||
device=device,
|
||||
)
|
||||
|
||||
args = dict(
|
||||
debug=debug,
|
||||
threaded=False,
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
|
||||
if ssl_certificate:
|
||||
assert ssl_key is not None
|
||||
args["ssl_context"] = (ssl_certificate, ssl_key)
|
||||
|
||||
app.run(**args) # type: ignore
|
||||
|
||||
|
||||
@app.route("/run_inference/", methods=["POST"])
|
||||
def run_inference():
|
||||
"""
|
||||
Execute the riffusion model as an API.
|
||||
|
||||
Inputs:
|
||||
Serialized JSON of the InferenceInput dataclass
|
||||
|
||||
Returns:
|
||||
Serialized JSON of the InferenceOutput dataclass
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Parse the payload as JSON
|
||||
json_data = json.loads(flask.request.data)
|
||||
|
||||
# Log the request
|
||||
logging.info(json_data)
|
||||
|
||||
# Parse an InferenceInput dataclass from the payload
|
||||
try:
|
||||
inputs = dacite.from_dict(InferenceInput, json_data)
|
||||
except dacite.exceptions.WrongTypeError as exception:
|
||||
logging.info(json_data)
|
||||
return str(exception), 400
|
||||
except dacite.exceptions.MissingValueError as exception:
|
||||
logging.info(json_data)
|
||||
return str(exception), 400
|
||||
|
||||
response = compute_request(
|
||||
inputs=inputs,
|
||||
seed_images_dir=SEED_IMAGES_DIR,
|
||||
pipeline=PIPELINE,
|
||||
)
|
||||
|
||||
# Log the total time
|
||||
logging.info(f"Request took {time.time() - start_time:.2f} s")
|
||||
|
||||
return response
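# Client-side sketch (illustrative, not part of the diff): how a caller can hit the
# endpoint above. Assumes run_app's defaults (host 127.0.0.1, port 3013), the third-party
# `requests` library, and that the fields below form the complete InferenceInput schema --
# they mirror the attributes referenced elsewhere in this diff. The seed image ID and
# prompts are hypothetical.
def _example_client() -> dict:
    import requests  # not a project dependency; used only for this sketch

    payload = {
        "alpha": 0.5,
        "num_inference_steps": 50,
        "seed_image_id": "og_beat",  # hypothetical ID; seed_images/og_beat.png must exist
        "mask_image_id": None,
        "start": {"prompt": "church bells", "seed": 42, "denoising": 0.75, "guidance": 7.0},
        "end": {"prompt": "jazz guitar", "seed": 123, "denoising": 0.75, "guidance": 7.0},
    }
    resp = requests.post("http://127.0.0.1:3013/run_inference/", data=json.dumps(payload))
    # InferenceOutput fields: "image" and "audio" are base64 data URIs, "duration_s" is a float.
    return resp.json()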
|
||||
|
||||
|
||||
def compute_request(
|
||||
inputs: InferenceInput,
|
||||
pipeline: RiffusionPipeline,
|
||||
seed_images_dir: str,
|
||||
) -> T.Union[str, T.Tuple[str, int]]:
|
||||
"""
|
||||
Does all the heavy lifting of the request.
|
||||
|
||||
Args:
|
||||
inputs: The input dataclass
|
||||
pipeline: The riffusion model pipeline
|
||||
seed_images_dir: The directory where seed images are stored
|
||||
"""
|
||||
# Load the seed image by ID
|
||||
init_image_path = Path(seed_images_dir, f"{inputs.seed_image_id}.png")
|
||||
|
||||
if not init_image_path.is_file():
|
||||
return f"Invalid seed image: {inputs.seed_image_id}", 400
|
||||
init_image = PIL.Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# Load the mask image by ID
|
||||
mask_image: T.Optional[PIL.Image.Image] = None
|
||||
if inputs.mask_image_id:
|
||||
mask_image_path = Path(seed_images_dir, f"{inputs.mask_image_id}.png")
|
||||
if not mask_image_path.is_file():
|
||||
return f"Invalid mask image: {inputs.mask_image_id}", 400
|
||||
mask_image = PIL.Image.open(str(mask_image_path)).convert("RGB")
|
||||
|
||||
# Execute the model to get the spectrogram image
|
||||
image = pipeline.riffuse(
|
||||
inputs,
|
||||
init_image=init_image,
|
||||
mask_image=mask_image,
|
||||
)
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
params = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
# Reconstruct audio from the image
|
||||
# TODO(hayk): It may help performance a bit to cache this object
|
||||
converter = SpectrogramImageConverter(params=params, device=str(pipeline.device))
|
||||
|
||||
segment = converter.audio_from_spectrogram_image(
|
||||
image,
|
||||
apply_filters=True,
|
||||
)
|
||||
|
||||
# Export audio to MP3 bytes
|
||||
mp3_bytes = io.BytesIO()
|
||||
segment.export(mp3_bytes, format="mp3")
|
||||
mp3_bytes.seek(0)
|
||||
|
||||
# Export image to JPEG bytes
|
||||
image_bytes = io.BytesIO()
|
||||
image.save(image_bytes, exif=image.getexif(), format="JPEG")
|
||||
image_bytes.seek(0)
|
||||
|
||||
# Assemble the output dataclass
|
||||
output = InferenceOutput(
|
||||
image="data:image/jpeg;base64," + base64_util.encode(image_bytes),
|
||||
audio="data:audio/mpeg;base64," + base64_util.encode(mp3_bytes),
|
||||
duration_s=segment.duration_seconds,
|
||||
)
|
||||
|
||||
return json.dumps(dataclasses.asdict(output))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argh
|
||||
|
||||
argh.dispatch_command(run_app)
|
|
@ -0,0 +1,204 @@
|
|||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import audio_util, torch_util
|
||||
|
||||
|
||||
class SpectrogramConverter:
|
||||
"""
|
||||
Convert between audio segments and spectrogram tensors using torchaudio.
|
||||
|
||||
In this class a "spectrogram" is defined as a (batch, time, frequency) tensor with float values
|
||||
that represent the amplitude of the frequency at that time bucket (in the frequency domain).
|
||||
    Frequencies are given in the perceptual Mel scale defined by the params. A more specific term
|
||||
used in some functions is "mel amplitudes".
|
||||
|
||||
The spectrogram computed from `spectrogram_from_audio` is complex valued, but it only
|
||||
returns the amplitude, because the phase is chaotic and hard to learn. The function
|
||||
`audio_from_spectrogram` is an approximate inverse of `spectrogram_from_audio`, which
|
||||
approximates the phase information using the Griffin-Lim algorithm.
|
||||
|
||||
Each channel in the audio is treated independently, and the spectrogram has a batch dimension
|
||||
equal to the number of channels in the input audio segment.
|
||||
|
||||
    Both the Griffin-Lim algorithm and the Mel scaling process are lossy.
|
||||
|
||||
For more information, see https://pytorch.org/audio/stable/transforms.html
|
||||
"""
|
||||
|
||||
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
|
||||
self.p = params
|
||||
|
||||
self.device = torch_util.check_device(device)
|
||||
|
||||
if device.lower().startswith("mps"):
|
||||
warnings.warn(
|
||||
"WARNING: MPS does not support audio operations, falling back to CPU for them",
|
||||
stacklevel=2,
|
||||
)
|
||||
self.device = "cpu"
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.Spectrogram.html
|
||||
self.spectrogram_func = torchaudio.transforms.Spectrogram(
|
||||
n_fft=params.n_fft,
|
||||
hop_length=params.hop_length,
|
||||
win_length=params.win_length,
|
||||
pad=0,
|
||||
window_fn=torch.hann_window,
|
||||
power=None,
|
||||
normalized=False,
|
||||
wkwargs=None,
|
||||
center=True,
|
||||
pad_mode="reflect",
|
||||
onesided=True,
|
||||
).to(self.device)
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.GriffinLim.html
|
||||
self.inverse_spectrogram_func = torchaudio.transforms.GriffinLim(
|
||||
n_fft=params.n_fft,
|
||||
n_iter=params.num_griffin_lim_iters,
|
||||
win_length=params.win_length,
|
||||
hop_length=params.hop_length,
|
||||
window_fn=torch.hann_window,
|
||||
power=1.0,
|
||||
wkwargs=None,
|
||||
momentum=0.99,
|
||||
length=None,
|
||||
rand_init=True,
|
||||
).to(self.device)
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.MelScale.html
|
||||
self.mel_scaler = torchaudio.transforms.MelScale(
|
||||
n_mels=params.num_frequencies,
|
||||
sample_rate=params.sample_rate,
|
||||
f_min=params.min_frequency,
|
||||
f_max=params.max_frequency,
|
||||
n_stft=params.n_fft // 2 + 1,
|
||||
norm=params.mel_scale_norm,
|
||||
mel_scale=params.mel_scale_type,
|
||||
).to(self.device)
|
||||
|
||||
# https://pytorch.org/audio/stable/generated/torchaudio.transforms.InverseMelScale.html
|
||||
self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
|
||||
n_stft=params.n_fft // 2 + 1,
|
||||
n_mels=params.num_frequencies,
|
||||
sample_rate=params.sample_rate,
|
||||
f_min=params.min_frequency,
|
||||
f_max=params.max_frequency,
|
||||
max_iter=params.max_mel_iters,
|
||||
tolerance_loss=1e-5,
|
||||
tolerance_change=1e-8,
|
||||
sgdargs=None,
|
||||
norm=params.mel_scale_norm,
|
||||
mel_scale=params.mel_scale_type,
|
||||
).to(self.device)
|
||||
|
||||
def spectrogram_from_audio(
|
||||
self,
|
||||
audio: pydub.AudioSegment,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute a spectrogram from an audio segment.
|
||||
|
||||
Args:
|
||||
audio: Audio segment which must match the sample rate of the params
|
||||
|
||||
Returns:
|
||||
spectrogram: (channel, frequency, time)
|
||||
"""
|
||||
assert int(audio.frame_rate) == self.p.sample_rate, "Audio sample rate must match params"
|
||||
|
||||
# Get the samples as a numpy array in (batch, samples) shape
|
||||
waveform = np.array([c.get_array_of_samples() for c in audio.split_to_mono()])
|
||||
|
||||
# Convert to floats if necessary
|
||||
if waveform.dtype != np.float32:
|
||||
waveform = waveform.astype(np.float32)
|
||||
|
||||
waveform_tensor = torch.from_numpy(waveform).to(self.device)
|
||||
amplitudes_mel = self.mel_amplitudes_from_waveform(waveform_tensor)
|
||||
return amplitudes_mel.cpu().numpy()
|
||||
|
||||
def audio_from_spectrogram(
|
||||
self,
|
||||
spectrogram: np.ndarray,
|
||||
apply_filters: bool = True,
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Reconstruct an audio segment from a spectrogram.
|
||||
|
||||
Args:
|
||||
spectrogram: (batch, frequency, time)
|
||||
apply_filters: Post-process with normalization and compression
|
||||
|
||||
Returns:
|
||||
audio: Audio segment with channels equal to the batch dimension
|
||||
"""
|
||||
# Move to device
|
||||
amplitudes_mel = torch.from_numpy(spectrogram).to(self.device)
|
||||
|
||||
# Reconstruct the waveform
|
||||
waveform = self.waveform_from_mel_amplitudes(amplitudes_mel)
|
||||
|
||||
# Convert to audio segment
|
||||
segment = audio_util.audio_from_waveform(
|
||||
samples=waveform.cpu().numpy(),
|
||||
sample_rate=self.p.sample_rate,
|
||||
# Normalize the waveform to the range [-1, 1]
|
||||
normalize=True,
|
||||
)
|
||||
|
||||
# Optionally apply post-processing filters
|
||||
if apply_filters:
|
||||
segment = audio_util.apply_filters(
|
||||
segment,
|
||||
compression=False,
|
||||
)
|
||||
|
||||
return segment
|
||||
|
||||
def mel_amplitudes_from_waveform(
|
||||
self,
|
||||
waveform: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Torch-only function to compute Mel-scale amplitudes from a waveform.
|
||||
|
||||
Args:
|
||||
waveform: (batch, samples)
|
||||
|
||||
Returns:
|
||||
amplitudes_mel: (batch, frequency, time)
|
||||
"""
|
||||
# Compute the complex-valued spectrogram
|
||||
spectrogram_complex = self.spectrogram_func(waveform)
|
||||
|
||||
# Take the magnitude
|
||||
amplitudes = torch.abs(spectrogram_complex)
|
||||
|
||||
# Convert to mel scale
|
||||
return self.mel_scaler(amplitudes)
|
||||
|
||||
def waveform_from_mel_amplitudes(
|
||||
self,
|
||||
amplitudes_mel: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Torch-only function to approximately reconstruct a waveform from Mel-scale amplitudes.
|
||||
|
||||
Args:
|
||||
amplitudes_mel: (batch, frequency, time)
|
||||
|
||||
Returns:
|
||||
waveform: (batch, samples)
|
||||
"""
|
||||
# Convert from mel scale to linear
|
||||
amplitudes_linear = self.inverse_mel_scaler(amplitudes_mel)
|
||||
|
||||
# Run the approximate algorithm to compute the phase and recover the waveform
|
||||
return self.inverse_spectrogram_func(amplitudes_linear)
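

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal round trip through SpectrogramConverter, assuming an audio file at
# `path` that pydub can read. The reconstruction is lossy due to the Mel scaling
# and Griffin-Lim phase estimation described in the class docstring.
def _example_round_trip(path: str = "clip.wav") -> pydub.AudioSegment:
    params = SpectrogramParams()
    converter = SpectrogramConverter(params=params, device="cpu")

    # Resample so the segment matches the sample rate expected by the params
    segment = pydub.AudioSegment.from_file(path).set_frame_rate(params.sample_rate)

    # Audio -> Mel amplitudes -> audio
    spectrogram = converter.spectrogram_from_audio(segment)
    return converter.audio_from_spectrogram(spectrogram, apply_filters=True)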
|
|
@ -0,0 +1,91 @@
|
|||
import numpy as np
|
||||
import pydub
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
|
||||
class SpectrogramImageConverter:
|
||||
"""
|
||||
Convert between spectrogram images and audio segments.
|
||||
|
||||
This is a wrapper around SpectrogramConverter that additionally converts from spectrograms
|
||||
to images and back. The real audio processing lives in SpectrogramConverter.
|
||||
"""
|
||||
|
||||
def __init__(self, params: SpectrogramParams, device: str = "cuda"):
|
||||
self.p = params
|
||||
self.device = device
|
||||
self.converter = SpectrogramConverter(params=params, device=device)
|
||||
|
||||
def spectrogram_image_from_audio(
|
||||
self,
|
||||
segment: pydub.AudioSegment,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Compute a spectrogram image from an audio segment.
|
||||
|
||||
Args:
|
||||
segment: Audio segment to convert
|
||||
|
||||
Returns:
|
||||
Spectrogram image (in pillow format)
|
||||
"""
|
||||
assert int(segment.frame_rate) == self.p.sample_rate, "Sample rate mismatch"
|
||||
|
||||
if self.p.stereo:
|
||||
if segment.channels == 1:
|
||||
print("WARNING: Mono audio but stereo=True, cloning channel")
|
||||
segment = segment.set_channels(2)
|
||||
elif segment.channels > 2:
|
||||
print("WARNING: Multi channel audio, reducing to stereo")
|
||||
segment = segment.set_channels(2)
|
||||
else:
|
||||
if segment.channels > 1:
|
||||
print("WARNING: Stereo audio but stereo=False, setting to mono")
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
spectrogram = self.converter.spectrogram_from_audio(segment)
|
||||
|
||||
image = image_util.image_from_spectrogram(
|
||||
spectrogram,
|
||||
power=self.p.power_for_image,
|
||||
)
|
||||
|
||||
# Store conversion params in exif metadata of the image
|
||||
exif_data = self.p.to_exif()
|
||||
exif_data[SpectrogramParams.ExifTags.MAX_VALUE.value] = float(np.max(spectrogram))
|
||||
exif = image.getexif()
|
||||
exif.update(exif_data.items())
|
||||
|
||||
return image
|
||||
|
||||
def audio_from_spectrogram_image(
|
||||
self,
|
||||
image: Image.Image,
|
||||
apply_filters: bool = True,
|
||||
max_value: float = 30e6,
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Reconstruct an audio segment from a spectrogram image.
|
||||
|
||||
Args:
|
||||
image: Spectrogram image (in pillow format)
|
||||
apply_filters: Apply post-processing to improve the reconstructed audio
|
||||
max_value: Scaled max amplitude of the spectrogram. Shouldn't matter.
|
||||
"""
|
||||
spectrogram = image_util.spectrogram_from_image(
|
||||
image,
|
||||
max_value=max_value,
|
||||
power=self.p.power_for_image,
|
||||
stereo=self.p.stereo,
|
||||
)
|
||||
|
||||
segment = self.converter.audio_from_spectrogram(
|
||||
spectrogram,
|
||||
apply_filters=apply_filters,
|
||||
)
|
||||
|
||||
return segment
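

# --- Illustrative usage sketch (not part of the original module) ---
# Round trip an audio clip through a spectrogram image and back. The file path is a
# placeholder; any audio pydub can read works once resampled to the params rate.
def _example_image_round_trip(path: str = "clip.wav") -> pydub.AudioSegment:
    params = SpectrogramParams()
    converter = SpectrogramImageConverter(params=params, device="cpu")
    segment = pydub.AudioSegment.from_file(path).set_frame_rate(params.sample_rate)
    image = converter.spectrogram_image_from_audio(segment)
    return converter.audio_from_spectrogram_image(image, apply_filters=True)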
|
|
@ -0,0 +1,115 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SpectrogramParams:
|
||||
"""
|
||||
Parameters for the conversion from audio to spectrograms to images and back.
|
||||
|
||||
Includes helpers to convert to and from EXIF tags, allowing these parameters to be stored
|
||||
within spectrogram images.
|
||||
|
||||
To understand what these parameters do and to customize them, read `spectrogram_converter.py`
|
||||
and the linked torchaudio documentation.
|
||||
"""
|
||||
|
||||
# Whether the audio is stereo or mono
|
||||
stereo: bool = False
|
||||
|
||||
# FFT parameters
|
||||
sample_rate: int = 44100
|
||||
step_size_ms: int = 10
|
||||
window_duration_ms: int = 100
|
||||
padded_duration_ms: int = 400
|
||||
|
||||
# Mel scale parameters
|
||||
num_frequencies: int = 512
|
||||
# TODO(hayk): Set these to [20, 20000] for newer models
|
||||
min_frequency: int = 0
|
||||
max_frequency: int = 10000
|
||||
mel_scale_norm: T.Optional[str] = None
|
||||
mel_scale_type: str = "htk"
|
||||
max_mel_iters: int = 200
|
||||
|
||||
# Griffin Lim parameters
|
||||
num_griffin_lim_iters: int = 32
|
||||
|
||||
# Image parameterization
|
||||
power_for_image: float = 0.25
|
||||
|
||||
class ExifTags(Enum):
|
||||
"""
|
||||
Custom EXIF tags for the spectrogram image.
|
||||
"""
|
||||
|
||||
SAMPLE_RATE = 11000
|
||||
STEREO = 11005
|
||||
STEP_SIZE_MS = 11010
|
||||
WINDOW_DURATION_MS = 11020
|
||||
PADDED_DURATION_MS = 11030
|
||||
|
||||
NUM_FREQUENCIES = 11040
|
||||
MIN_FREQUENCY = 11050
|
||||
MAX_FREQUENCY = 11060
|
||||
|
||||
POWER_FOR_IMAGE = 11070
|
||||
MAX_VALUE = 11080
|
||||
|
||||
@property
|
||||
def n_fft(self) -> int:
|
||||
"""
|
||||
The number of samples in each STFT window, with padding.
|
||||
"""
|
||||
return int(self.padded_duration_ms / 1000.0 * self.sample_rate)
|
||||
|
||||
@property
|
||||
def win_length(self) -> int:
|
||||
"""
|
||||
The number of samples in each STFT window.
|
||||
"""
|
||||
return int(self.window_duration_ms / 1000.0 * self.sample_rate)
|
||||
|
||||
@property
|
||||
def hop_length(self) -> int:
|
||||
"""
|
||||
The number of samples between each STFT window.
|
||||
"""
|
||||
return int(self.step_size_ms / 1000.0 * self.sample_rate)
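
    # Worked example with the defaults above (sample_rate=44100, padded_duration_ms=400,
    # window_duration_ms=100, step_size_ms=10):
    #   n_fft      = int(0.400 * 44100) = 17640 samples
    #   win_length = int(0.100 * 44100) = 4410 samples
    #   hop_length = int(0.010 * 44100) = 441 samples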
|
||||
|
||||
def to_exif(self) -> T.Dict[int, T.Any]:
|
||||
"""
|
||||
Return a dictionary of EXIF tags for the current values.
|
||||
"""
|
||||
return {
|
||||
self.ExifTags.SAMPLE_RATE.value: self.sample_rate,
|
||||
self.ExifTags.STEREO.value: self.stereo,
|
||||
self.ExifTags.STEP_SIZE_MS.value: self.step_size_ms,
|
||||
self.ExifTags.WINDOW_DURATION_MS.value: self.window_duration_ms,
|
||||
self.ExifTags.PADDED_DURATION_MS.value: self.padded_duration_ms,
|
||||
self.ExifTags.NUM_FREQUENCIES.value: self.num_frequencies,
|
||||
self.ExifTags.MIN_FREQUENCY.value: self.min_frequency,
|
||||
self.ExifTags.MAX_FREQUENCY.value: self.max_frequency,
|
||||
self.ExifTags.POWER_FOR_IMAGE.value: float(self.power_for_image),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_exif(cls, exif: T.Mapping[int, T.Any]) -> SpectrogramParams:
|
||||
"""
|
||||
Create a SpectrogramParams object from the EXIF tags of the given image.
|
||||
"""
|
||||
# TODO(hayk): Handle missing tags
|
||||
return cls(
|
||||
sample_rate=exif[cls.ExifTags.SAMPLE_RATE.value],
|
||||
stereo=bool(exif[cls.ExifTags.STEREO.value]),
|
||||
step_size_ms=exif[cls.ExifTags.STEP_SIZE_MS.value],
|
||||
window_duration_ms=exif[cls.ExifTags.WINDOW_DURATION_MS.value],
|
||||
padded_duration_ms=exif[cls.ExifTags.PADDED_DURATION_MS.value],
|
||||
num_frequencies=exif[cls.ExifTags.NUM_FREQUENCIES.value],
|
||||
min_frequency=exif[cls.ExifTags.MIN_FREQUENCY.value],
|
||||
max_frequency=exif[cls.ExifTags.MAX_FREQUENCY.value],
|
||||
power_for_image=exif[cls.ExifTags.POWER_FOR_IMAGE.value],
|
||||
)
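

# --- Illustrative sketch (not part of the original module) ---
# Parameters survive a trip through image EXIF metadata, which is how spectrogram
# images stay self-describing. PIL is imported locally since this module does not
# otherwise depend on it.
def _example_exif_round_trip() -> SpectrogramParams:
    from PIL import Image

    params = SpectrogramParams(stereo=True)
    image = Image.new("RGB", (512, 512))
    exif = image.getexif()
    exif.update(params.to_exif().items())
    return SpectrogramParams.from_exif(exif)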
|
|
@ -0,0 +1,3 @@
|
|||
# streamlit
|
||||
|
||||
This package is an interactive streamlit app for riffusion.
|
|
@ -0,0 +1,404 @@
|
|||
import io
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.datatypes import InferenceInput, PromptInput
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
from riffusion.streamlit.pages.interpolation import get_prompt_inputs, run_interpolation
|
||||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def render_audio_to_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":wave: Audio to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Modify existing audio from a text prompt or interpolate between two.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool allows you to upload an audio file of arbitrary length and modify it with
|
||||
a text prompt. It does this by sweeping over the audio in overlapping clips, doing
|
||||
img2img style transfer with riffusion, then stitching the clips back together with
|
||||
cross fading to eliminate seams.
|
||||
|
||||
            Try a denoising strength of 0.4 for light modification and 0.55 for heavier
            modification. The best denoising value depends on how different the prompt is
            from the source audio. You can play with the seed to get infinite variations.
            Currently the same seed is used for all clips along the track.

            If the Interpolation checkbox is enabled, you can enter two sets of prompt,
            seed, and denoising values, and the tool smoothly blends between them along the
            selected duration of the audio. This is a great way to create a transition.
|
||||
"""
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
use_magic_mix = st.sidebar.checkbox("Use Magic Mix", False)
|
||||
lora_path = st.sidebar.text_input("Lora Path", "")
|
||||
lora_scale = st.sidebar.number_input("Lora Scale", value=1.0)
|
||||
|
||||
with st.sidebar:
|
||||
num_inference_steps = T.cast(
|
||||
int,
|
||||
st.number_input(
|
||||
"Steps per sample", value=50, help="Number of denoising steps per model run"
|
||||
),
|
||||
)
|
||||
|
||||
guidance = st.number_input(
|
||||
"Guidance",
|
||||
value=7.0,
|
||||
help="How much the model listens to the text prompt",
|
||||
)
|
||||
|
||||
scheduler = st.selectbox(
|
||||
"Scheduler",
|
||||
options=streamlit_util.SCHEDULER_OPTIONS,
|
||||
index=0,
|
||||
help="Which diffusion scheduler to use",
|
||||
)
|
||||
assert scheduler is not None
|
||||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload audio",
|
||||
type=streamlit_util.AUDIO_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
|
||||
if not audio_file:
|
||||
st.info("Upload audio to get started")
|
||||
return
|
||||
|
||||
st.write("#### Original")
|
||||
st.audio(audio_file)
|
||||
|
||||
segment = streamlit_util.load_audio_file(audio_file)
|
||||
|
||||
# TODO(hayk): Fix
|
||||
if segment.frame_rate != 44100:
|
||||
st.warning("Audio must be 44100Hz. Converting")
|
||||
segment = segment.set_frame_rate(44100)
|
||||
st.write(f"Duration: {segment.duration_seconds:.2f}s, Sample Rate: {segment.frame_rate}Hz")
|
||||
|
||||
clip_p = get_clip_params()
|
||||
start_time_s = clip_p["start_time_s"]
|
||||
clip_duration_s = clip_p["clip_duration_s"]
|
||||
overlap_duration_s = clip_p["overlap_duration_s"]
|
||||
|
||||
duration_s = min(clip_p["duration_s"], segment.duration_seconds - start_time_s)
|
||||
increment_s = clip_duration_s - overlap_duration_s
|
||||
clip_start_times = start_time_s + np.arange(0, duration_s - clip_duration_s, increment_s)
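    # For example, with the defaults (start_time_s=0, duration_s=15, clip_duration_s=5,
    # overlap_duration_s=0.2) and a sufficiently long input, the increment is 4.8s and
    # the clips start at [0.0, 4.8, 9.6] seconds.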
|
||||
|
||||
write_clip_details(
|
||||
clip_start_times=clip_start_times,
|
||||
clip_duration_s=clip_duration_s,
|
||||
overlap_duration_s=overlap_duration_s,
|
||||
)
|
||||
|
||||
interpolate = st.checkbox(
|
||||
"Interpolate between two endpoints",
|
||||
value=False,
|
||||
help="Interpolate between two prompts, seeds, or denoising values along the"
|
||||
"duration of the segment",
|
||||
)
|
||||
|
||||
counter = streamlit_util.StreamlitCounter()
|
||||
|
||||
with st.form("audio to audio form"):
|
||||
if interpolate:
|
||||
left, right = st.columns(2)
|
||||
|
||||
with left:
|
||||
st.write("##### Prompt A")
|
||||
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
|
||||
|
||||
with right:
|
||||
st.write("##### Prompt B")
|
||||
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
|
||||
elif use_magic_mix:
|
||||
prompt = st.text_input("Prompt", key="prompt_a")
|
||||
|
||||
row = st.columns(4)
|
||||
|
||||
seed = T.cast(
|
||||
int,
|
||||
row[0].number_input(
|
||||
"Seed",
|
||||
value=42,
|
||||
key="seed_a",
|
||||
),
|
||||
)
|
||||
prompt_input_a = PromptInput(
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
guidance=guidance,
|
||||
)
|
||||
magic_mix_kmin = row[1].number_input("Kmin", value=0.3)
|
||||
magic_mix_kmax = row[2].number_input("Kmax", value=0.5)
|
||||
magic_mix_mix_factor = row[3].number_input("Mix Factor", value=0.5)
|
||||
else:
|
||||
prompt_input_a = PromptInput(
|
||||
guidance=guidance,
|
||||
**get_prompt_inputs(key="a", include_negative_prompt=True, cols=True),
|
||||
)
|
||||
|
||||
st.form_submit_button("Riff", type="primary", on_click=counter.increment)
|
||||
|
||||
show_clip_details = st.sidebar.checkbox("Show Clip Details", True)
|
||||
show_difference = st.sidebar.checkbox("Show Difference", False)
|
||||
|
||||
clip_segments = slice_audio_into_clips(
|
||||
segment=segment,
|
||||
clip_start_times=clip_start_times,
|
||||
clip_duration_s=clip_duration_s,
|
||||
)
|
||||
|
||||
if not prompt_input_a.prompt:
|
||||
st.info("Enter a prompt")
|
||||
return
|
||||
|
||||
if counter.value == 0:
|
||||
return
|
||||
|
||||
params = SpectrogramParams()
|
||||
|
||||
if interpolate:
|
||||
# TODO(hayk): Make not linspace
|
||||
alphas = list(np.linspace(0, 1, len(clip_segments)))
|
||||
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
||||
st.write(f"**Alphas** : [{alphas_str}]")
|
||||
|
||||
result_images: T.List[Image.Image] = []
|
||||
result_segments: T.List[pydub.AudioSegment] = []
|
||||
for i, clip_segment in enumerate(clip_segments):
|
||||
st.write(f"### Clip {i} at {clip_start_times[i]:.2f}s")
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
clip_segment.export(audio_bytes, format="wav")
|
||||
|
||||
init_image = streamlit_util.spectrogram_image_from_audio(
|
||||
clip_segment,
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# TODO(hayk): Roll this into spectrogram_image_from_audio?
|
||||
init_image_resized = scale_image_to_32_stride(init_image)
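        # Rounding up to a multiple of 32 keeps the dimensions compatible with the
        # diffusion model's downsampling; the result is resized back to the original
        # size after generation.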
|
||||
|
||||
progress_callback = None
|
||||
if show_clip_details:
|
||||
left, right = st.columns(2)
|
||||
|
||||
left.write("##### Source Clip")
|
||||
left.image(init_image, use_column_width=False)
|
||||
left.audio(audio_bytes)
|
||||
|
||||
right.write("##### Riffed Clip")
|
||||
empty_bin = right.empty()
|
||||
with empty_bin.container():
|
||||
st.info("Riffing...")
|
||||
progress = st.progress(0.0)
|
||||
progress_callback = progress.progress
|
||||
|
||||
if interpolate:
|
||||
assert use_magic_mix is False, "Cannot use magic mix and interpolate together"
|
||||
inputs = InferenceInput(
|
||||
alpha=float(alphas[i]),
|
||||
num_inference_steps=num_inference_steps,
|
||||
seed_image_id="og_beat",
|
||||
start=prompt_input_a,
|
||||
end=prompt_input_b,
|
||||
)
|
||||
|
||||
image, audio_bytes = run_interpolation(
|
||||
inputs=inputs,
|
||||
init_image=init_image_resized,
|
||||
device=device,
|
||||
)
|
||||
elif use_magic_mix:
|
||||
assert not prompt_input_a.negative_prompt, "No negative prompt with magic mix"
|
||||
image = streamlit_util.run_img2img_magic_mix(
|
||||
prompt=prompt_input_a.prompt,
|
||||
init_image=init_image_resized,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance,
|
||||
seed=prompt_input_a.seed,
|
||||
kmin=magic_mix_kmin,
|
||||
kmax=magic_mix_kmax,
|
||||
mix_factor=magic_mix_mix_factor,
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
else:
|
||||
image = streamlit_util.run_img2img(
|
||||
prompt=prompt_input_a.prompt,
|
||||
init_image=init_image_resized,
|
||||
denoising_strength=prompt_input_a.denoising,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance,
|
||||
negative_prompt=prompt_input_a.negative_prompt,
|
||||
seed=prompt_input_a.seed,
|
||||
progress_callback=progress_callback,
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
lora_path=lora_path,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
# Resize back to original size
|
||||
image = image.resize(init_image.size, Image.BICUBIC)
|
||||
|
||||
result_images.append(image)
|
||||
|
||||
if show_clip_details:
|
||||
empty_bin.empty()
|
||||
right.image(image, use_column_width=False)
|
||||
|
||||
riffed_segment = streamlit_util.audio_segment_from_spectrogram_image(
|
||||
image=image,
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
result_segments.append(riffed_segment)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
riffed_segment.export(audio_bytes, format="wav")
|
||||
|
||||
if show_clip_details:
|
||||
right.audio(audio_bytes)
|
||||
|
||||
if show_clip_details and show_difference:
|
||||
diff_np = np.maximum(
|
||||
0, np.asarray(init_image).astype(np.float32) - np.asarray(image).astype(np.float32)
|
||||
)
|
||||
diff_image = Image.fromarray(255 - diff_np.astype(np.uint8))
|
||||
diff_segment = streamlit_util.audio_segment_from_spectrogram_image(
|
||||
image=diff_image,
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
diff_segment.export(audio_bytes, format=extension)
|
||||
st.audio(audio_bytes)
|
||||
|
||||
# Combine clips with a crossfade based on overlap
|
||||
combined_segment = audio_util.stitch_segments(result_segments, crossfade_s=overlap_duration_s)
|
||||
|
||||
st.write(f"#### Final Audio ({combined_segment.duration_seconds}s)")
|
||||
|
||||
input_name = Path(audio_file.name).stem
|
||||
output_name = f"{input_name}_{prompt_input_a.prompt.replace(' ', '_')}"
|
||||
streamlit_util.display_and_download_audio(combined_segment, output_name, extension=extension)
|
||||
|
||||
|
||||
def get_clip_params(advanced: bool = False) -> T.Dict[str, T.Any]:
|
||||
"""
|
||||
Render the parameters of slicing audio into clips.
|
||||
"""
|
||||
p: T.Dict[str, T.Any] = {}
|
||||
|
||||
cols = st.columns(4)
|
||||
|
||||
p["start_time_s"] = cols[0].number_input(
|
||||
"Start Time [s]",
|
||||
min_value=0.0,
|
||||
value=0.0,
|
||||
)
|
||||
p["duration_s"] = cols[1].number_input(
|
||||
"Duration [s]",
|
||||
min_value=0.0,
|
||||
value=15.0,
|
||||
)
|
||||
|
||||
if advanced:
|
||||
p["clip_duration_s"] = cols[2].number_input(
|
||||
"Clip Duration [s]",
|
||||
min_value=3.0,
|
||||
max_value=10.0,
|
||||
value=5.0,
|
||||
)
|
||||
else:
|
||||
p["clip_duration_s"] = 5.0
|
||||
|
||||
if advanced:
|
||||
p["overlap_duration_s"] = cols[3].number_input(
|
||||
"Overlap Duration [s]",
|
||||
min_value=0.0,
|
||||
max_value=10.0,
|
||||
value=0.2,
|
||||
)
|
||||
else:
|
||||
p["overlap_duration_s"] = 0.2
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def write_clip_details(
|
||||
clip_start_times: np.ndarray, clip_duration_s: float, overlap_duration_s: float
|
||||
):
|
||||
"""
|
||||
Write details of the clips to be sliced from an audio segment.
|
||||
"""
|
||||
clip_details_text = (
|
||||
f"Slicing {len(clip_start_times)} clips of duration {clip_duration_s}s "
|
||||
f"with overlap {overlap_duration_s}s"
|
||||
)
|
||||
|
||||
with st.expander(clip_details_text):
|
||||
st.dataframe(
|
||||
{
|
||||
"Start Time [s]": clip_start_times,
|
||||
"End Time [s]": clip_start_times + clip_duration_s,
|
||||
"Duration [s]": clip_duration_s,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def slice_audio_into_clips(
|
||||
segment: pydub.AudioSegment, clip_start_times: T.Sequence[float], clip_duration_s: float
|
||||
) -> T.List[pydub.AudioSegment]:
|
||||
"""
|
||||
Slice an audio segment into a list of clips of a given duration at the given start times.
|
||||
"""
|
||||
clip_segments: T.List[pydub.AudioSegment] = []
|
||||
for i, clip_start_time_s in enumerate(clip_start_times):
|
||||
clip_start_time_ms = int(clip_start_time_s * 1000)
|
||||
clip_duration_ms = int(clip_duration_s * 1000)
|
||||
clip_segment = segment[clip_start_time_ms : clip_start_time_ms + clip_duration_ms]
|
||||
|
||||
# TODO(hayk): I don't think this is working properly
|
||||
if i == len(clip_start_times) - 1:
|
||||
silence_ms = clip_duration_ms - int(clip_segment.duration_seconds * 1000)
|
||||
if silence_ms > 0:
|
||||
clip_segment = clip_segment.append(pydub.AudioSegment.silent(duration=silence_ms))
|
||||
|
||||
clip_segments.append(clip_segment)
|
||||
|
||||
return clip_segments
|
||||
|
||||
|
||||
def scale_image_to_32_stride(image: Image.Image) -> Image.Image:
|
||||
"""
|
||||
Scale an image to a size that is a multiple of 32.
|
||||
"""
|
||||
closest_width = int(np.ceil(image.width / 32) * 32)
|
||||
closest_height = int(np.ceil(image.height / 32) * 32)
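    # e.g. a 998 x 512 image is resized to 1024 x 512; dimensions that are already
    # multiples of 32 are left unchanged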
|
||||
return image.resize((closest_width, closest_height), Image.BICUBIC)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_audio_to_audio()
|
|
@ -0,0 +1,74 @@
|
|||
import dataclasses
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
from riffusion.util.image_util import exif_from_image
|
||||
|
||||
|
||||
def render_image_to_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":musical_keyboard: Image to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Reconstruct audio from spectrogram images.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool takes an existing spectrogram image and reconstructs it into an audio
|
||||
waveform. It also displays the EXIF metadata stored inside the image, which can
|
||||
            contain the parameters used to create the spectrogram image. If the image has no
            EXIF metadata, default parameters are assumed.
|
||||
"""
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
image_file = st.file_uploader(
|
||||
"Upload a file",
|
||||
type=streamlit_util.IMAGE_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if not image_file:
|
||||
st.info("Upload an image file to get started")
|
||||
return
|
||||
|
||||
image = Image.open(image_file)
|
||||
st.image(image)
|
||||
|
||||
with st.expander("Image metadata", expanded=False):
|
||||
exif = exif_from_image(image)
|
||||
st.json(exif)
|
||||
|
||||
try:
|
||||
params = SpectrogramParams.from_exif(exif=image.getexif())
|
||||
except KeyError:
|
||||
st.info("Could not find spectrogram parameters in exif data. Using defaults.")
|
||||
params = SpectrogramParams()
|
||||
|
||||
with st.expander("Spectrogram Parameters", expanded=False):
|
||||
st.json(dataclasses.asdict(params))
|
||||
|
||||
segment = streamlit_util.audio_segment_from_spectrogram_image(
|
||||
image=image.copy(),
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
|
||||
streamlit_util.display_and_download_audio(
|
||||
segment,
|
||||
name=Path(image_file.name).stem,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_image_to_audio()
|
|
@ -0,0 +1,273 @@
|
|||
import dataclasses
|
||||
import io
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
import streamlit as st
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.datatypes import InferenceInput, PromptInput
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_interpolation() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":performing_arts: Interpolation")
|
||||
st.write(
|
||||
"""
|
||||
Interpolate between prompts in the latent space.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool allows specifying two endpoints and generating a long-form interpolation
|
||||
between them that traverses the latent space. The interpolation is generated by
|
||||
the method described at https://www.riffusion.com/about. A seed image is used to
|
||||
set the beat and tempo of the generated audio, and can be set in the sidebar.
|
||||
            Usually either the seed or the prompt is changed, but not both at once. You can
            browse infinite variations of the same prompt by changing the seed.
|
||||
|
||||
For example, try going from "church bells" to "jazz" with 10 steps and 0.75 denoising.
|
||||
This will generate a 50 second clip at 5 seconds per step. Then play with the seeds
|
||||
or denoising to get different variations.
|
||||
"""
|
||||
)
|
||||
|
||||
# Sidebar params
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
num_interpolation_steps = T.cast(
|
||||
int,
|
||||
st.sidebar.number_input(
|
||||
"Interpolation steps",
|
||||
value=4,
|
||||
min_value=1,
|
||||
max_value=20,
|
||||
help="Number of model generations between the two prompts. Controls the duration.",
|
||||
),
|
||||
)
|
||||
|
||||
num_inference_steps = T.cast(
|
||||
int,
|
||||
st.sidebar.number_input(
|
||||
"Steps per sample", value=50, help="Number of denoising steps per model run"
|
||||
),
|
||||
)
|
||||
|
||||
guidance = st.sidebar.number_input(
|
||||
"Guidance",
|
||||
value=7.0,
|
||||
help="How much the model listens to the text prompt",
|
||||
)
|
||||
|
||||
init_image_name = st.sidebar.selectbox(
|
||||
"Seed image",
|
||||
# TODO(hayk): Read from directory
|
||||
options=["og_beat", "agile", "marim", "motorway", "vibes", "custom"],
|
||||
index=0,
|
||||
help="Which seed image to use for img2img. Custom allows uploading your own.",
|
||||
)
|
||||
assert init_image_name is not None
|
||||
if init_image_name == "custom":
|
||||
init_image_file = st.sidebar.file_uploader(
|
||||
"Upload a custom seed image",
|
||||
type=streamlit_util.IMAGE_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if init_image_file:
|
||||
st.sidebar.image(init_image_file)
|
||||
|
||||
show_individual_outputs = st.sidebar.checkbox(
|
||||
"Show individual outputs",
|
||||
value=False,
|
||||
help="Show each model output",
|
||||
)
|
||||
show_images = st.sidebar.checkbox(
|
||||
"Show individual images",
|
||||
value=False,
|
||||
help="Show each generated image",
|
||||
)
|
||||
|
||||
# Prompt inputs A and B in two columns
|
||||
|
||||
with st.form(key="interpolation_form"):
|
||||
left, right = st.columns(2)
|
||||
|
||||
with left:
|
||||
st.write("##### Prompt A")
|
||||
prompt_input_a = PromptInput(guidance=guidance, **get_prompt_inputs(key="a"))
|
||||
|
||||
with right:
|
||||
st.write("##### Prompt B")
|
||||
prompt_input_b = PromptInput(guidance=guidance, **get_prompt_inputs(key="b"))
|
||||
|
||||
st.form_submit_button("Generate", type="primary")
|
||||
|
||||
if not prompt_input_a.prompt or not prompt_input_b.prompt:
|
||||
st.info("Enter both prompts to interpolate between them")
|
||||
return
|
||||
|
||||
alphas = list(np.linspace(0, 1, num_interpolation_steps))
|
||||
alphas_str = ", ".join([f"{alpha:.2f}" for alpha in alphas])
|
||||
st.write(f"**Alphas** : [{alphas_str}]")
|
||||
|
||||
# TODO(hayk): Apply scaling to alphas like this
|
||||
# T_shifted = T * 2 - 1
|
||||
# T_sample = (np.abs(T_shifted)**t_scale_power * np.sign(T_shifted) + 1) / 2
|
||||
# T_sample = T_sample * (t_end - t_start) + t_start
|
||||
|
||||
if init_image_name == "custom":
|
||||
if not init_image_file:
|
||||
st.info("Upload a custom seed image")
|
||||
return
|
||||
init_image = Image.open(init_image_file).convert("RGB")
|
||||
else:
|
||||
init_image_path = (
|
||||
Path(__file__).parent.parent.parent.parent / "seed_images" / f"{init_image_name}.png"
|
||||
)
|
||||
init_image = Image.open(str(init_image_path)).convert("RGB")
|
||||
|
||||
# TODO(hayk): Move this code into a shared place and add to riffusion.cli
|
||||
image_list: T.List[Image.Image] = []
|
||||
audio_bytes_list: T.List[io.BytesIO] = []
|
||||
for i, alpha in enumerate(alphas):
|
||||
inputs = InferenceInput(
|
||||
alpha=float(alpha),
|
||||
num_inference_steps=num_inference_steps,
|
||||
seed_image_id="og_beat",
|
||||
start=prompt_input_a,
|
||||
end=prompt_input_b,
|
||||
)
|
||||
|
||||
if i == 0:
|
||||
with st.expander("Example input JSON", expanded=False):
|
||||
st.json(dataclasses.asdict(inputs))
|
||||
|
||||
image, audio_bytes = run_interpolation(
|
||||
inputs=inputs,
|
||||
init_image=init_image,
|
||||
device=device,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
if show_individual_outputs:
|
||||
st.write(f"#### ({i + 1} / {len(alphas)}) Alpha={alpha:.2f}")
|
||||
if show_images:
|
||||
st.image(image)
|
||||
st.audio(audio_bytes)
|
||||
|
||||
image_list.append(image)
|
||||
audio_bytes_list.append(audio_bytes)
|
||||
|
||||
st.write("#### Final Output")
|
||||
|
||||
# TODO(hayk): Concatenate with overlap and better blending like in audio to audio
|
||||
audio_segments = [pydub.AudioSegment.from_file(audio_bytes) for audio_bytes in audio_bytes_list]
|
||||
concat_segment = audio_segments[0]
|
||||
for segment in audio_segments[1:]:
|
||||
concat_segment = concat_segment.append(segment, crossfade=0)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
concat_segment.export(audio_bytes, format=extension)
|
||||
audio_bytes.seek(0)
|
||||
|
||||
st.write(f"Duration: {concat_segment.duration_seconds:.3f} seconds")
|
||||
st.audio(audio_bytes)
|
||||
|
||||
output_name = (
|
||||
f"{prompt_input_a.prompt.replace(' ', '_')}_"
|
||||
f"{prompt_input_b.prompt.replace(' ', '_')}.{extension}"
|
||||
)
|
||||
st.download_button(
|
||||
output_name,
|
||||
data=audio_bytes,
|
||||
file_name=output_name,
|
||||
mime=f"audio/{extension}",
|
||||
)
|
||||
|
||||
|
||||
def get_prompt_inputs(
|
||||
key: str,
|
||||
include_negative_prompt: bool = False,
|
||||
cols: bool = False,
|
||||
) -> T.Dict[str, T.Any]:
|
||||
"""
|
||||
Compute prompt inputs from widgets.
|
||||
"""
|
||||
p: T.Dict[str, T.Any] = {}
|
||||
|
||||
# Optionally use columns
|
||||
left, right = T.cast(T.Any, st.columns(2) if cols else (st, st))
|
||||
|
||||
visibility = "visible" if cols else "collapsed"
|
||||
p["prompt"] = left.text_input("Prompt", label_visibility=visibility, key=f"prompt_{key}")
|
||||
|
||||
if include_negative_prompt:
|
||||
p["negative_prompt"] = right.text_input("Negative Prompt", key=f"negative_prompt_{key}")
|
||||
|
||||
p["seed"] = T.cast(
|
||||
int,
|
||||
left.number_input(
|
||||
"Seed",
|
||||
value=42,
|
||||
key=f"seed_{key}",
|
||||
help="Integer used to generate a random result. Vary this to explore alternatives.",
|
||||
),
|
||||
)
|
||||
|
||||
p["denoising"] = right.number_input(
|
||||
"Denoising",
|
||||
value=0.5,
|
||||
key=f"denoising_{key}",
|
||||
help="How much to modify the seed image",
|
||||
)
|
||||
|
||||
return p
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
def run_interpolation(
|
||||
inputs: InferenceInput, init_image: Image.Image, device: str = "cuda", extension: str = "mp3"
|
||||
) -> T.Tuple[Image.Image, io.BytesIO]:
|
||||
"""
|
||||
Cached function for riffusion interpolation.
|
||||
"""
|
||||
pipeline = streamlit_util.load_riffusion_checkpoint(
|
||||
device=device,
|
||||
# No trace so we can have variable width
|
||||
no_traced_unet=True,
|
||||
)
|
||||
|
||||
image = pipeline.riffuse(
|
||||
inputs,
|
||||
init_image=init_image,
|
||||
mask_image=None,
|
||||
)
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
params = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
# Reconstruct from image to audio
|
||||
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
|
||||
image=image,
|
||||
params=params,
|
||||
device=device,
|
||||
output_format=extension,
|
||||
)
|
||||
|
||||
return image, audio_bytes
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_interpolation()
|
|
@ -0,0 +1,131 @@
|
|||
import tempfile
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_sample_clips() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":paperclip: Sample Clips")
|
||||
st.write(
|
||||
"""
|
||||
Export short clips from an audio file.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool simply allows uploading an audio file and randomly sampling short clips
|
||||
from it. It's useful for generating a large number of short clips from a single
|
||||
audio file. Outputs can be saved to a given directory with a given audio extension.
|
||||
"""
|
||||
)
|
||||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload a file",
|
||||
type=streamlit_util.AUDIO_EXTENSIONS,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
if not audio_file:
|
||||
st.info("Upload an audio file to get started")
|
||||
return
|
||||
|
||||
st.audio(audio_file)
|
||||
|
||||
segment = pydub.AudioSegment.from_file(audio_file)
|
||||
st.write(
|
||||
" \n".join(
|
||||
[
|
||||
f"**Duration**: {segment.duration_seconds:.3f} seconds",
|
||||
f"**Channels**: {segment.channels}",
|
||||
f"**Sample rate**: {segment.frame_rate} Hz",
|
||||
f"**Sample width**: {segment.sample_width} bytes",
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
save_to_disk = st.sidebar.checkbox("Save to Disk", False)
|
||||
export_as_mono = st.sidebar.checkbox("Export as Mono", False)
|
||||
compute_spectrograms = st.sidebar.checkbox("Compute Spectrograms", False)
|
||||
|
||||
row = st.columns(4)
|
||||
num_clips = T.cast(int, row[0].number_input("Number of Clips", value=3))
|
||||
duration_ms = T.cast(int, row[1].number_input("Duration (ms)", value=5000))
|
||||
seed = T.cast(int, row[2].number_input("Seed", value=42))
|
||||
|
||||
counter = streamlit_util.StreamlitCounter()
|
||||
st.button("Sample Clips", type="primary", on_click=counter.increment)
|
||||
if counter.value == 0:
|
||||
return
|
||||
|
||||
# Optionally pick an output directory
|
||||
if save_to_disk:
|
||||
output_dir = tempfile.mkdtemp(prefix="sample_clips_")
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
st.info(f"Output directory: `{output_dir}`")
|
||||
|
||||
if compute_spectrograms:
|
||||
images_dir = output_path / "images"
|
||||
images_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if seed >= 0:
|
||||
np.random.seed(seed)
|
||||
|
||||
if export_as_mono and segment.channels > 1:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
if save_to_disk:
|
||||
st.info(f"Writing {num_clips} clip(s) to `{str(output_path)}`")
|
||||
|
||||
# TODO(hayk): Share code with riffusion.cli.sample_clips.
|
||||
segment_duration_ms = int(segment.duration_seconds * 1000)
|
||||
for i in range(num_clips):
|
||||
clip_start_ms = np.random.randint(0, segment_duration_ms - duration_ms)
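        # e.g. for a 60s file with duration_ms=5000, start times are drawn uniformly
        # from [0, 55000) ms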
|
||||
clip = segment[clip_start_ms : clip_start_ms + duration_ms]
|
||||
|
||||
clip_name = f"clip_{i}_start_{clip_start_ms}_ms_duration_{duration_ms}_ms"
|
||||
|
||||
st.write(f"#### Clip {i + 1} / {num_clips} -- `{clip_name}`")
|
||||
|
||||
streamlit_util.display_and_download_audio(
|
||||
clip,
|
||||
name=clip_name,
|
||||
extension=extension,
|
||||
)
|
||||
|
||||
if save_to_disk:
|
||||
clip_path = output_path / f"{clip_name}.{extension}"
|
||||
clip.export(clip_path, format=extension)
|
||||
|
||||
if compute_spectrograms:
|
||||
params = SpectrogramParams()
|
||||
|
||||
image = streamlit_util.spectrogram_image_from_audio(
|
||||
clip,
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
|
||||
st.image(image)
|
||||
|
||||
if save_to_disk:
|
||||
image_path = images_dir / f"{clip_name}.jpeg"
|
||||
image.save(image_path)
|
||||
|
||||
if save_to_disk:
|
||||
st.info(f"Wrote {num_clips} clip(s) to `{str(output_path)}`")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_sample_clips()
|
|
@ -0,0 +1,105 @@
|
|||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import pydub
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.audio_splitter import split_audio
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
from riffusion.util import audio_util
|
||||
|
||||
|
||||
def render_split_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":scissors: Audio Splitter")
|
||||
st.write(
|
||||
"""
|
||||
Split audio into individual instrument stems.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool allows uploading an audio file of arbitrary length and splits it into
|
||||
stems of vocals, drums, bass, and other. It does this using a deep network that
|
||||
            sweeps over the audio in clips, extracts the stems, and then cross-fades the clips
            back together to construct the full-length stems. It's particularly useful in
|
||||
combination with audio_to_audio, for example to split and preserve vocals while
|
||||
modifying the rest of the track with a prompt. Or, to pull out drums to add later
|
||||
in a DAW.
|
||||
"""
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
|
||||
extension_options = ["mp3", "wav", "m4a", "ogg", "flac", "webm"]
|
||||
extension = st.sidebar.selectbox(
|
||||
"Output format",
|
||||
options=extension_options,
|
||||
index=extension_options.index("mp3"),
|
||||
)
|
||||
assert extension is not None
|
||||
|
||||
audio_file = st.file_uploader(
|
||||
"Upload audio",
|
||||
type=extension_options,
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
|
||||
stem_options = ["Vocals", "Drums", "Bass", "Guitar", "Piano", "Other"]
|
||||
recombine = st.sidebar.multiselect(
|
||||
"Recombine",
|
||||
options=stem_options,
|
||||
default=[],
|
||||
help="Recombine these stems at the end",
|
||||
)
|
||||
|
||||
if not audio_file:
|
||||
st.info("Upload audio to get started")
|
||||
return
|
||||
|
||||
st.write("#### Original")
|
||||
st.audio(audio_file)
|
||||
|
||||
counter = streamlit_util.StreamlitCounter()
|
||||
st.button("Split", type="primary", on_click=counter.increment)
|
||||
if counter.value == 0:
|
||||
return
|
||||
|
||||
segment = streamlit_util.load_audio_file(audio_file)
|
||||
|
||||
# Split
|
||||
stems = split_audio_cached(segment, device=device)
|
||||
|
||||
input_name = Path(audio_file.name).stem
|
||||
|
||||
# Display each
|
||||
for name in stem_options:
|
||||
stem = stems[name.lower()]
|
||||
st.write(f"#### Stem: {name}")
|
||||
|
||||
output_name = f"{input_name}_{name.lower()}"
|
||||
streamlit_util.display_and_download_audio(stem, output_name, extension=extension)
|
||||
|
||||
if recombine:
|
||||
recombine_lower = [r.lower() for r in recombine]
|
||||
segments = [s for name, s in stems.items() if name in recombine_lower]
|
||||
recombined = audio_util.overlay_segments(segments)
|
||||
|
||||
# Display
|
||||
st.write(f"#### Recombined: {', '.join(recombine)}")
|
||||
output_name = f"{input_name}_{'_'.join(recombine_lower)}"
|
||||
streamlit_util.display_and_download_audio(recombined, output_name, extension=extension)
|
||||
|
||||
|
||||
@st.cache
|
||||
def split_audio_cached(
|
||||
segment: pydub.AudioSegment, device: str = "cuda"
|
||||
) -> T.Dict[str, pydub.AudioSegment]:
|
||||
return split_audio(segment, device=device)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_split_audio()
|
|
@ -0,0 +1,118 @@
|
|||
import typing as T
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
|
||||
def render_text_to_audio() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":pencil2: Text to Audio")
|
||||
st.write(
|
||||
"""
|
||||
Generate audio from text prompts.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool runs riffusion in the simplest text to image form to generate an audio
|
||||
clip from a text prompt. There is no seed image or interpolation here. This mode
|
||||
            allows more diversity and creativity than using a seed image, but it also gives
            you less control. Play with the seed to get infinite variations.
|
||||
"""
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
extension = streamlit_util.select_audio_extension(st.sidebar)
|
||||
|
||||
lora_path = st.sidebar.text_input("Lora Path", "")
|
||||
lora_scale = st.sidebar.number_input("Lora Scale", value=1.0)
|
||||
|
||||
with st.form("Inputs"):
|
||||
prompt = st.text_input("Prompt")
|
||||
negative_prompt = st.text_input("Negative prompt")
|
||||
|
||||
row = st.columns(4)
|
||||
num_clips = T.cast(
|
||||
int,
|
||||
row[0].number_input(
|
||||
"Number of clips",
|
||||
value=1,
|
||||
min_value=1,
|
||||
max_value=25,
|
||||
help="How many outputs to generate (seed gets incremented)",
|
||||
),
|
||||
)
|
||||
starting_seed = T.cast(
|
||||
int,
|
||||
row[1].number_input(
|
||||
"Seed",
|
||||
value=42,
|
||||
help="Change this to generate different variations",
|
||||
),
|
||||
)
|
||||
|
||||
st.form_submit_button("Riff", type="primary")
|
||||
|
||||
with st.sidebar:
|
||||
num_inference_steps = T.cast(int, st.number_input("Inference steps", value=50))
|
||||
width = T.cast(int, st.number_input("Width", value=512))
|
||||
guidance = st.number_input(
|
||||
"Guidance", value=7.0, help="How much the model listens to the text prompt"
|
||||
)
|
||||
scheduler = st.selectbox(
|
||||
"Scheduler",
|
||||
options=streamlit_util.SCHEDULER_OPTIONS,
|
||||
index=0,
|
||||
help="Which diffusion scheduler to use",
|
||||
)
|
||||
assert scheduler is not None
|
||||
|
||||
if not prompt:
|
||||
st.info("Enter a prompt")
|
||||
return
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
params = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
seed = starting_seed
|
||||
for i in range(1, num_clips + 1):
|
||||
st.write(f"#### Riff {i} / {num_clips} - Seed {seed}")
|
||||
|
||||
image = streamlit_util.run_txt2img(
|
||||
prompt=prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance=guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
width=width,
|
||||
height=512,
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
lora_path=lora_path,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
st.image(image)
|
||||
|
||||
segment = streamlit_util.audio_segment_from_spectrogram_image(
|
||||
image=image,
|
||||
params=params,
|
||||
device=device,
|
||||
)
|
||||
|
||||
streamlit_util.display_and_download_audio(
|
||||
segment, name=f"{prompt.replace(' ', '_')}_{seed}", extension=extension
|
||||
)
|
||||
|
||||
seed += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_text_to_audio()
|
|
@ -0,0 +1,147 @@
|
|||
import json
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.streamlit import util as streamlit_util
|
||||
|
||||
# Example input json file to process in batch
|
||||
EXAMPLE_INPUT = """
|
||||
{
|
||||
"params": {
|
||||
"seed": 42,
|
||||
"num_inference_steps": 50,
|
||||
"guidance": 7.0,
|
||||
"width": 512,
|
||||
},
|
||||
"entries": [
|
||||
{
|
||||
"prompt": "Church bells"
|
||||
},
|
||||
{
|
||||
"prompt": "electronic beats",
|
||||
"negative_prompt": "drums"
|
||||
},
|
||||
{
|
||||
"prompt": "classical violin concerto"
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def render_text_to_audio_batch() -> None:
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.subheader(":scroll: Text to Audio Batch")
|
||||
st.write(
|
||||
"""
|
||||
Generate audio in batch from a JSON file of text prompts.
|
||||
"""
|
||||
)
|
||||
|
||||
with st.expander("Help", False):
|
||||
st.write(
|
||||
"""
|
||||
This tool is a batch form of text_to_audio, where the inputs are read in from a JSON
|
||||
file. The input file contains a global params block and a list of entries with positive
|
||||
and negative prompts. It's useful for automating a larger set of generations. See the
|
||||
example inputs below for the format of the file.
|
||||
"""
|
||||
)
|
||||
|
||||
device = streamlit_util.select_device(st.sidebar)
|
||||
|
||||
# Upload a JSON file
|
||||
json_file = st.file_uploader(
|
||||
"JSON file",
|
||||
type=["json"],
|
||||
label_visibility="collapsed",
|
||||
)
|
||||
|
||||
# Handle the null case
|
||||
if json_file is None:
|
||||
st.info("Upload a JSON file containing params and prompts")
|
||||
with st.expander("Example inputs.json", expanded=False):
|
||||
st.code(EXAMPLE_INPUT)
|
||||
return
|
||||
|
||||
# Read in and print it
|
||||
data = json.loads(json_file.read())
|
||||
with st.expander("Input Data", expanded=False):
|
||||
st.json(data)
|
||||
|
||||
params = data["params"]
|
||||
entries = data["entries"]
|
||||
|
||||
show_images = st.sidebar.checkbox("Show Images", False)
|
||||
|
||||
# Optionally specify an output directory
|
||||
output_dir = st.sidebar.text_input("Output Directory", "")
|
||||
output_path: T.Optional[Path] = None
|
||||
if output_dir:
|
||||
output_path = Path(output_dir)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for i, entry in enumerate(entries):
|
||||
st.write(f"#### Entry {i + 1} / {len(entries)}")
|
||||
|
||||
negative_prompt = entry.get("negative_prompt", None)
|
||||
|
||||
st.write(f"**Prompt**: {entry['prompt']} \n" + f"**Negative prompt**: {negative_prompt}")
|
||||
|
||||
image = streamlit_util.run_txt2img(
|
||||
prompt=entry["prompt"],
|
||||
negative_prompt=negative_prompt,
|
||||
seed=params.get("seed", 42),
|
||||
num_inference_steps=params.get("num_inference_steps", 50),
|
||||
guidance=params.get("guidance", 7.0),
|
||||
width=params.get("width", 512),
|
||||
height=512,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if show_images:
|
||||
st.image(image)
|
||||
|
||||
# TODO(hayk): Change the frequency range to [20, 20k] once the model is retrained
|
||||
p_spectrogram = SpectrogramParams(
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
output_format = "mp3"
|
||||
audio_bytes = streamlit_util.audio_bytes_from_spectrogram_image(
|
||||
image=image,
|
||||
params=p_spectrogram,
|
||||
device=device,
|
||||
output_format=output_format,
|
||||
)
|
||||
st.audio(audio_bytes)
|
||||
|
||||
if output_path:
|
||||
prompt_slug = entry["prompt"].replace(" ", "_")
|
||||
negative_prompt_slug = entry.get("negative_prompt", "").replace(" ", "_")
|
||||
|
||||
image_path = output_path / f"image_{i}_{prompt_slug}_neg_{negative_prompt_slug}.jpg"
|
||||
image.save(image_path, format="JPEG")
|
||||
entry["image_path"] = str(image_path)
|
||||
|
||||
audio_path = (
|
||||
output_path / f"audio_{i}_{prompt_slug}_neg_{negative_prompt_slug}.{output_format}"
|
||||
)
|
||||
audio_path.write_bytes(audio_bytes.getbuffer())
|
||||
entry["audio_path"] = str(audio_path)
|
||||
|
||||
if output_path:
|
||||
output_json_path = output_path / "index.json"
|
||||
output_json_path.write_text(json.dumps(data, indent=4))
|
||||
st.info(f"Output written to {str(output_path)}")
|
||||
else:
|
||||
st.info("Enter output directory in sidebar to save to disk")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_text_to_audio_batch()
|
|
@ -0,0 +1,43 @@
|
|||
import streamlit as st
|
||||
|
||||
|
||||
def render_main():
|
||||
st.set_page_config(layout="wide", page_icon="🎸")
|
||||
|
||||
st.title(":guitar: Riffusion Playground")
|
||||
|
||||
left, right = st.columns(2)
|
||||
|
||||
with left:
|
||||
create_link(":pencil2: Text to Audio", "/text_to_audio")
|
||||
st.write("Generate audio clips from text prompts.")
|
||||
|
||||
create_link(":wave: Audio to Audio", "/audio_to_audio")
|
||||
st.write("Upload audio and modify with text prompt (interpolation supported).")
|
||||
|
||||
create_link(":performing_arts: Interpolation", "/interpolation")
|
||||
st.write("Interpolate between prompts in the latent space.")
|
||||
|
||||
create_link(":scissors: Audio Splitter", "/split_audio")
|
||||
st.write("Split audio into stems like vocals, bass, drums, guitar, etc.")
|
||||
|
||||
with right:
|
||||
create_link(":scroll: Text to Audio Batch", "/text_to_audio_batch")
|
||||
st.write("Generate audio in batch from a JSON file of text prompts.")
|
||||
|
||||
create_link(":paperclip: Sample Clips", "/sample_clips")
|
||||
st.write("Export short clips from an audio file.")
|
||||
|
||||
create_link(":musical_keyboard: Image to Audio", "/image_to_audio")
|
||||
st.write("Reconstruct audio from spectrogram images.")
|
||||
|
||||
|
||||
def create_link(name: str, url: str) -> None:
|
||||
st.markdown(
|
||||
f"### <a href='{url}' target='_self' style='text-decoration: none;'>{name}</a>",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
render_main()
|
|
@ -0,0 +1,457 @@
|
|||
"""
|
||||
Streamlit utilities (mostly cached wrappers around riffusion code).
|
||||
"""
|
||||
import io
|
||||
import threading
|
||||
import typing as T
|
||||
from pathlib import Path
|
||||
|
||||
import pydub
|
||||
import streamlit as st
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.audio_splitter import AudioSplitter
|
||||
from riffusion.riffusion_pipeline import RiffusionPipeline
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
# TODO(hayk): Add URL params
|
||||
|
||||
AUDIO_EXTENSIONS = ["mp3", "wav", "flac", "webm", "m4a", "ogg"]
|
||||
IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"]
|
||||
|
||||
SCHEDULER_OPTIONS = [
|
||||
"PNDMScheduler",
|
||||
"DDIMScheduler",
|
||||
"LMSDiscreteScheduler",
|
||||
"EulerDiscreteScheduler",
|
||||
"EulerAncestralDiscreteScheduler",
|
||||
"DPMSolverMultistepScheduler",
|
||||
]
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def load_riffusion_checkpoint(
|
||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||
no_traced_unet: bool = False,
|
||||
device: str = "cuda",
|
||||
) -> RiffusionPipeline:
|
||||
"""
|
||||
Load the riffusion pipeline.
|
||||
"""
|
||||
return RiffusionPipeline.load_checkpoint(
|
||||
checkpoint=checkpoint,
|
||||
use_traced_unet=not no_traced_unet,
|
||||
device=device,
|
||||
)
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def load_stable_diffusion_pipeline(
|
||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||
lora_path: T.Optional[str] = None,
|
||||
lora_scale: float = 1.0,
|
||||
) -> StableDiffusionPipeline:
|
||||
"""
|
||||
    Load the Stable Diffusion txt2img pipeline from the riffusion checkpoint.
|
||||
|
||||
TODO(hayk): Merge this into RiffusionPipeline to just load one model.
|
||||
"""
|
||||
if device == "cpu" or device.lower().startswith("mps"):
|
||||
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
|
||||
dtype = torch.float32
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=dtype,
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to(device)
|
||||
|
||||
pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
|
||||
|
||||
if lora_path:
|
||||
        if not Path(lora_path).is_file():
|
||||
            raise RuntimeError(f"LoRA path is not a file: {lora_path}")
|
||||
|
||||
from lora_diffusion import patch_pipe, tune_lora_scale
|
||||
|
||||
patch_pipe(
|
||||
pipeline,
|
||||
lora_path,
|
||||
patch_text=True,
|
||||
patch_ti=True,
|
||||
patch_unet=True,
|
||||
)
|
||||
tune_lora_scale(pipeline.unet, lora_scale)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
def get_scheduler(scheduler: str, config: T.Any) -> T.Any:
|
||||
"""
|
||||
Construct a denoising scheduler from a string.
|
||||
"""
|
||||
if scheduler == "PNDMScheduler":
|
||||
from diffusers import PNDMScheduler
|
||||
|
||||
return PNDMScheduler.from_config(config)
|
||||
elif scheduler == "DPMSolverMultistepScheduler":
|
||||
from diffusers import DPMSolverMultistepScheduler
|
||||
|
||||
return DPMSolverMultistepScheduler.from_config(config)
|
||||
elif scheduler == "DDIMScheduler":
|
||||
from diffusers import DDIMScheduler
|
||||
|
||||
return DDIMScheduler.from_config(config)
|
||||
elif scheduler == "LMSDiscreteScheduler":
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
|
||||
return LMSDiscreteScheduler.from_config(config)
|
||||
elif scheduler == "EulerDiscreteScheduler":
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
|
||||
return EulerDiscreteScheduler.from_config(config)
|
||||
elif scheduler == "EulerAncestralDiscreteScheduler":
|
||||
from diffusers import EulerAncestralDiscreteScheduler
|
||||
|
||||
return EulerAncestralDiscreteScheduler.from_config(config)
|
||||
else:
|
||||
raise ValueError(f"Unknown scheduler {scheduler}")
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def pipeline_lock() -> threading.Lock:
|
||||
"""
|
||||
Singleton lock used to prevent concurrent access to any model pipeline.
|
||||
"""
|
||||
return threading.Lock()
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def load_stable_diffusion_img2img_pipeline(
|
||||
checkpoint: str = "riffusion/riffusion-model-v1",
|
||||
device: str = "cuda",
|
||||
dtype: torch.dtype = torch.float16,
|
||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||
lora_path: T.Optional[str] = None,
|
||||
lora_scale: float = 1.0,
|
||||
) -> StableDiffusionImg2ImgPipeline:
|
||||
"""
|
||||
Load the image to image pipeline.
|
||||
|
||||
TODO(hayk): Merge this into RiffusionPipeline to just load one model.
|
||||
"""
|
||||
if device == "cpu" or device.lower().startswith("mps"):
|
||||
print(f"WARNING: Falling back to float32 on {device}, float16 is unsupported")
|
||||
dtype = torch.float32
|
||||
|
||||
pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
checkpoint,
|
||||
revision="main",
|
||||
torch_dtype=dtype,
|
||||
safety_checker=lambda images, **kwargs: (images, False),
|
||||
).to(device)
|
||||
|
||||
pipeline.scheduler = get_scheduler(scheduler, config=pipeline.scheduler.config)
|
||||
|
||||
# TODO reduce duplication
|
||||
if lora_path:
|
||||
        if not Path(lora_path).is_file():
|
||||
            raise RuntimeError(f"LoRA path is not a file: {lora_path}")
|
||||
|
||||
from lora_diffusion import patch_pipe, tune_lora_scale
|
||||
|
||||
patch_pipe(
|
||||
pipeline,
|
||||
lora_path,
|
||||
patch_text=True,
|
||||
patch_ti=True,
|
||||
patch_unet=True,
|
||||
)
|
||||
tune_lora_scale(pipeline.unet, lora_scale)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
def run_txt2img(
|
||||
prompt: str,
|
||||
num_inference_steps: int,
|
||||
guidance: float,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
width: int,
|
||||
height: int,
|
||||
device: str = "cuda",
|
||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||
lora_path: T.Optional[str] = None,
|
||||
lora_scale: float = 1.0,
|
||||
) -> Image.Image:
|
||||
"""
|
||||
Run the text to image pipeline with caching.
|
||||
"""
|
||||
with pipeline_lock():
|
||||
pipeline = load_stable_diffusion_pipeline(
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
lora_path=lora_path,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
generator_device = "cpu" if device.lower().startswith("mps") else device
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance,
|
||||
negative_prompt=negative_prompt or None,
|
||||
generator=generator,
|
||||
width=width,
|
||||
height=height,
|
||||
)
|
||||
|
||||
return output["images"][0]
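# Editorial sketch (not part of the original diff): a minimal call to run_txt2img.
# The prompt, seed, step count, and 512x512 size are arbitrary example values;
# pass device="cpu" on a machine without CUDA.
def _example_run_txt2img() -> Image.Image:
    return run_txt2img(
        prompt="acoustic folk guitar",
        num_inference_steps=30,
        guidance=7.0,
        negative_prompt="",
        seed=42,
        width=512,
        height=512,
        device="cuda",
    )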
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def spectrogram_image_converter(
|
||||
params: SpectrogramParams,
|
||||
device: str = "cuda",
|
||||
) -> SpectrogramImageConverter:
|
||||
return SpectrogramImageConverter(params=params, device=device)
|
||||
|
||||
|
||||
@st.cache
|
||||
def spectrogram_image_from_audio(
|
||||
segment: pydub.AudioSegment,
|
||||
params: SpectrogramParams,
|
||||
device: str = "cuda",
|
||||
) -> Image.Image:
|
||||
converter = spectrogram_image_converter(params=params, device=device)
|
||||
return converter.spectrogram_image_from_audio(segment)
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
def audio_segment_from_spectrogram_image(
|
||||
image: Image.Image,
|
||||
params: SpectrogramParams,
|
||||
device: str = "cuda",
|
||||
) -> pydub.AudioSegment:
|
||||
converter = spectrogram_image_converter(params=params, device=device)
|
||||
return converter.audio_from_spectrogram_image(image)
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
def audio_bytes_from_spectrogram_image(
|
||||
image: Image.Image,
|
||||
params: SpectrogramParams,
|
||||
device: str = "cuda",
|
||||
output_format: str = "mp3",
|
||||
) -> io.BytesIO:
|
||||
segment = audio_segment_from_spectrogram_image(image=image, params=params, device=device)
|
||||
|
||||
audio_bytes = io.BytesIO()
|
||||
segment.export(audio_bytes, format=output_format)
|
||||
|
||||
return audio_bytes
|
||||
|
||||
|
||||
def select_device(container: T.Any = st.sidebar) -> str:
|
||||
"""
|
||||
Dropdown to select a torch device, with an intelligent default.
|
||||
"""
|
||||
default_device = "cpu"
|
||||
if torch.cuda.is_available():
|
||||
default_device = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
default_device = "mps"
|
||||
|
||||
device_options = ["cuda", "cpu", "mps"]
|
||||
    device = container.selectbox(
|
||||
"Device",
|
||||
options=device_options,
|
||||
index=device_options.index(default_device),
|
||||
help="Which compute device to use. CUDA is recommended.",
|
||||
)
|
||||
assert device is not None
|
||||
|
||||
return device
|
||||
|
||||
|
||||
def select_audio_extension(container: T.Any = st.sidebar) -> str:
|
||||
"""
|
||||
Dropdown to select an audio extension, with an intelligent default.
|
||||
"""
|
||||
default = "mp3" if pydub.AudioSegment.ffmpeg else "wav"
|
||||
extension = container.selectbox(
|
||||
"Output format",
|
||||
options=AUDIO_EXTENSIONS,
|
||||
index=AUDIO_EXTENSIONS.index(default),
|
||||
)
|
||||
assert extension is not None
|
||||
return extension
|
||||
|
||||
|
||||
def select_scheduler(container: T.Any = st.sidebar) -> str:
|
||||
"""
|
||||
Dropdown to select a scheduler.
|
||||
"""
|
||||
    scheduler = container.selectbox(
|
||||
"Scheduler",
|
||||
options=SCHEDULER_OPTIONS,
|
||||
index=0,
|
||||
help="Which diffusion scheduler to use",
|
||||
)
|
||||
assert scheduler is not None
|
||||
return scheduler
|
||||
|
||||
|
||||
@st.experimental_memo
|
||||
def load_audio_file(audio_file: io.BytesIO) -> pydub.AudioSegment:
|
||||
return pydub.AudioSegment.from_file(audio_file)
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def get_audio_splitter(device: str = "cuda"):
|
||||
return AudioSplitter(device=device)
|
||||
|
||||
|
||||
@st.experimental_singleton
|
||||
def load_magic_mix_pipeline(device: str = "cuda", scheduler: str = SCHEDULER_OPTIONS[0]):
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"riffusion/riffusion-model-v1",
|
||||
custom_pipeline="magic_mix",
|
||||
).to(device)
|
||||
|
||||
pipeline.scheduler = get_scheduler(scheduler, pipeline.scheduler.config)
|
||||
|
||||
return pipeline
|
||||
|
||||
|
||||
@st.cache
|
||||
def run_img2img_magic_mix(
|
||||
prompt: str,
|
||||
init_image: Image.Image,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
kmin: float,
|
||||
kmax: float,
|
||||
mix_factor: float,
|
||||
device: str = "cuda",
|
||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||
):
|
||||
"""
|
||||
Run the magic mix pipeline for img2img.
|
||||
"""
|
||||
with pipeline_lock():
|
||||
pipeline = load_magic_mix_pipeline(
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
return pipeline(
|
||||
init_image,
|
||||
prompt=prompt,
|
||||
kmin=kmin,
|
||||
kmax=kmax,
|
||||
mix_factor=mix_factor,
|
||||
seed=seed,
|
||||
guidance_scale=guidance_scale,
|
||||
steps=num_inference_steps,
|
||||
)
|
||||
|
||||
|
||||
@st.cache
|
||||
def run_img2img(
|
||||
prompt: str,
|
||||
init_image: Image.Image,
|
||||
denoising_strength: float,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
negative_prompt: T.Optional[str] = None,
|
||||
device: str = "cuda",
|
||||
scheduler: str = SCHEDULER_OPTIONS[0],
|
||||
progress_callback: T.Optional[T.Callable[[float], T.Any]] = None,
|
||||
lora_path: T.Optional[str] = None,
|
||||
lora_scale: float = 1.0,
|
||||
) -> Image.Image:
|
||||
with pipeline_lock():
|
||||
pipeline = load_stable_diffusion_img2img_pipeline(
|
||||
device=device,
|
||||
scheduler=scheduler,
|
||||
lora_path=lora_path,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
|
||||
generator_device = "cpu" if device.lower().startswith("mps") else device
|
||||
generator = torch.Generator(device=generator_device).manual_seed(seed)
|
||||
|
||||
num_expected_steps = max(int(num_inference_steps * denoising_strength), 1)
|
||||
|
||||
def callback(step: int, tensor: torch.Tensor, foo: T.Any) -> None:
|
||||
if progress_callback is not None:
|
||||
progress_callback(step / num_expected_steps)
|
||||
|
||||
result = pipeline(
|
||||
prompt=prompt,
|
||||
image=init_image,
|
||||
strength=denoising_strength,
|
||||
num_inference_steps=num_inference_steps,
|
||||
guidance_scale=guidance_scale,
|
||||
negative_prompt=negative_prompt or None,
|
||||
num_images_per_prompt=1,
|
||||
generator=generator,
|
||||
callback=callback,
|
||||
callback_steps=1,
|
||||
)
|
||||
|
||||
return result.images[0]
|
||||
|
||||
|
||||
class StreamlitCounter:
|
||||
"""
|
||||
Simple counter stored in streamlit session state.
|
||||
"""
|
||||
|
||||
def __init__(self, key="_counter"):
|
||||
self.key = key
|
||||
if not st.session_state.get(self.key):
|
||||
st.session_state[self.key] = 0
|
||||
|
||||
def increment(self):
|
||||
st.session_state[self.key] += 1
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return st.session_state[self.key]
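# Editorial sketch (not part of the original diff): a counter can back a
# "generate again" button, since clicking increments the value stored in
# session state. The key below is an arbitrary example.
def _example_counter_button() -> int:
    counter = StreamlitCounter(key="_example_counter")
    st.button("Riff again", on_click=counter.increment)
    return counter.value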
|
||||
|
||||
|
||||
def display_and_download_audio(
|
||||
segment: pydub.AudioSegment,
|
||||
name: str,
|
||||
extension: str = "mp3",
|
||||
) -> None:
|
||||
"""
|
||||
Display the given audio segment and provide a button to download it with
|
||||
a proper file name, since st.audio doesn't support that.
|
||||
"""
|
||||
mime_type = f"audio/{extension}"
|
||||
audio_bytes = io.BytesIO()
|
||||
segment.export(audio_bytes, format=extension)
|
||||
st.audio(audio_bytes, format=mime_type)
|
||||
|
||||
st.download_button(
|
||||
f"{name}.{extension}",
|
||||
data=audio_bytes,
|
||||
file_name=f"{name}.{extension}",
|
||||
mime=mime_type,
|
||||
)
|
|
@ -0,0 +1,99 @@
|
|||
"""
|
||||
Audio utility functions.
|
||||
"""
|
||||
|
||||
import io
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import pydub
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
def audio_from_waveform(
|
||||
samples: np.ndarray, sample_rate: int, normalize: bool = False
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Convert a numpy array of samples of a waveform to an audio segment.
|
||||
|
||||
Args:
|
||||
samples: (channels, samples) array
|
||||
"""
|
||||
# Normalize volume to fit in int16
|
||||
if normalize:
|
||||
        samples = samples * (np.iinfo(np.int16).max / np.max(np.abs(samples)))
|
||||
|
||||
# Transpose and convert to int16
|
||||
samples = samples.transpose(1, 0)
|
||||
samples = samples.astype(np.int16)
|
||||
|
||||
# Write to the bytes of a WAV file
|
||||
wav_bytes = io.BytesIO()
|
||||
wavfile.write(wav_bytes, sample_rate, samples)
|
||||
wav_bytes.seek(0)
|
||||
|
||||
# Read into pydub
|
||||
return pydub.AudioSegment.from_wav(wav_bytes)
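# Editorial sketch (not part of the original diff): build a one-second 440 Hz
# stereo sine wave and convert it to a pydub segment. The sample rate and
# frequency are arbitrary example values.
def _example_audio_from_waveform() -> pydub.AudioSegment:
    sample_rate = 44100
    t = np.linspace(0, 1.0, sample_rate, endpoint=False)
    sine = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)
    samples = np.stack([sine, sine])  # (channels, samples)
    return audio_from_waveform(samples, sample_rate=sample_rate, normalize=True)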
|
||||
|
||||
|
||||
def apply_filters(segment: pydub.AudioSegment, compression: bool = False) -> pydub.AudioSegment:
|
||||
"""
|
||||
Apply post-processing filters to the audio segment to compress it and
|
||||
    keep it at roughly a -10 dBFS level.
|
||||
"""
|
||||
# TODO(hayk): Come up with a principled strategy for these filters and experiment end-to-end.
|
||||
# TODO(hayk): Is this going to make audio unbalanced between sequential clips?
|
||||
|
||||
if compression:
|
||||
segment = pydub.effects.normalize(
|
||||
segment,
|
||||
headroom=0.1,
|
||||
)
|
||||
|
||||
segment = segment.apply_gain(-10 - segment.dBFS)
|
||||
|
||||
# TODO(hayk): This is quite slow, ~1.7 seconds on a beefy CPU
|
||||
segment = pydub.effects.compress_dynamic_range(
|
||||
segment,
|
||||
threshold=-20.0,
|
||||
ratio=4.0,
|
||||
attack=5.0,
|
||||
release=50.0,
|
||||
)
|
||||
|
||||
desired_db = -12
|
||||
segment = segment.apply_gain(desired_db - segment.dBFS)
|
||||
|
||||
segment = pydub.effects.normalize(
|
||||
segment,
|
||||
headroom=0.1,
|
||||
)
|
||||
|
||||
return segment
|
||||
|
||||
|
||||
def stitch_segments(
|
||||
segments: T.Sequence[pydub.AudioSegment], crossfade_s: float
|
||||
) -> pydub.AudioSegment:
|
||||
"""
|
||||
Stitch together a sequence of audio segments with a crossfade between each segment.
|
||||
"""
|
||||
crossfade_ms = int(crossfade_s * 1000)
|
||||
combined_segment = segments[0]
|
||||
for segment in segments[1:]:
|
||||
combined_segment = combined_segment.append(segment, crossfade=crossfade_ms)
|
||||
return combined_segment
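# Editorial sketch (not part of the original diff): stitching two 2-second clips
# with a 0.5 second crossfade yields roughly 3.5 seconds of audio, since the
# crossfade overlaps the ends. Durations are arbitrary example values.
def _example_stitch_segments() -> pydub.AudioSegment:
    a = pydub.AudioSegment.silent(duration=2000)
    b = pydub.AudioSegment.silent(duration=2000)
    stitched = stitch_segments([a, b], crossfade_s=0.5)
    assert abs(stitched.duration_seconds - 3.5) < 0.01
    return stitched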
|
||||
|
||||
|
||||
def overlay_segments(segments: T.Sequence[pydub.AudioSegment]) -> pydub.AudioSegment:
|
||||
"""
|
||||
Overlay a sequence of audio segments on top of each other.
|
||||
"""
|
||||
assert len(segments) > 0
|
||||
    output: T.Optional[pydub.AudioSegment] = None
|
||||
for segment in segments:
|
||||
if output is None:
|
||||
output = segment
|
||||
else:
|
||||
output = output.overlay(segment)
|
||||
return output
|
|
@ -0,0 +1,9 @@
|
|||
import base64
|
||||
import io
|
||||
|
||||
|
||||
def encode(buffer: io.BytesIO) -> str:
|
||||
"""
|
||||
Encode the given buffer as base64.
|
||||
"""
|
||||
return base64.encodebytes(buffer.getvalue()).decode("ascii")
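# Editorial sketch (not part of the original diff): the encoded string round-trips
# back to the original bytes with the standard library decoder, which ignores the
# newlines that encodebytes inserts.
def _example_round_trip(data: bytes = b"riffusion") -> bytes:
    encoded = encode(io.BytesIO(data))
    return base64.b64decode(encoded)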
|
|
@ -0,0 +1,60 @@
|
|||
"""
|
||||
FFT tools to analyze frequency content of audio segments. This is not code for
|
||||
dealing with spectrogram images, but for analysis of waveforms.
|
||||
"""
|
||||
import struct
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
import plotly.graph_objects as go
|
||||
import pydub
|
||||
from scipy.fft import rfft, rfftfreq
|
||||
|
||||
|
||||
def plot_ffts(
|
||||
segments: T.Dict[str, pydub.AudioSegment],
|
||||
title: str = "FFT",
|
||||
min_frequency: float = 20,
|
||||
max_frequency: float = 20000,
|
||||
) -> None:
|
||||
"""
|
||||
Plot an FFT analysis of the given audio segments.
|
||||
"""
|
||||
ffts = {name: compute_fft(seg) for name, seg in segments.items()}
|
||||
|
||||
fig = go.Figure(
|
||||
data=[go.Scatter(x=data[0], y=data[1], name=name) for name, data in ffts.items()],
|
||||
layout={"title": title},
|
||||
)
|
||||
fig.update_xaxes(
|
||||
range=[np.log(min_frequency) / np.log(10), np.log(max_frequency) / np.log(10)],
|
||||
type="log",
|
||||
title="Frequency",
|
||||
)
|
||||
fig.update_yaxes(title="Value")
|
||||
fig.show()
|
||||
|
||||
|
||||
def compute_fft(sound: pydub.AudioSegment) -> T.Tuple[np.ndarray, np.ndarray]:
|
||||
"""
|
||||
Compute the FFT of the given audio segment as a mono signal.
|
||||
|
||||
Returns:
|
||||
frequencies: FFT computed frequencies
|
||||
amplitudes: Amplitudes of each frequency
|
||||
"""
|
||||
# Convert to mono if needed.
|
||||
if sound.channels > 1:
|
||||
sound = sound.set_channels(1)
|
||||
|
||||
sample_rate = sound.frame_rate
|
||||
|
||||
num_samples = int(sound.frame_count())
|
||||
samples = struct.unpack(f"{num_samples * sound.channels}h", sound.raw_data)
|
||||
|
||||
fft_values = rfft(samples)
|
||||
amplitudes = np.abs(fft_values)
|
||||
|
||||
frequencies = rfftfreq(n=num_samples, d=1 / sample_rate)
|
||||
|
||||
return frequencies, amplitudes
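# Editorial sketch (not part of the original diff): the FFT of a pure tone should
# peak near the tone's frequency. The 440 Hz generator and 5 Hz tolerance are
# arbitrary example values.
def _example_peak_frequency() -> float:
    from pydub.generators import Sine

    tone = Sine(440, sample_rate=44100).to_audio_segment(duration=1000)
    frequencies, amplitudes = compute_fft(tone)
    peak_hz = float(frequencies[np.argmax(amplitudes)])
    assert abs(peak_hz - 440.0) < 5.0
    return peak_hz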
|
|
@ -0,0 +1,122 @@
|
|||
"""
|
||||
Module for converting between spectrogram tensors and spectrogram images, as well as
|
||||
general helpers for operating on pillow images.
|
||||
"""
|
||||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
|
||||
def image_from_spectrogram(spectrogram: np.ndarray, power: float = 0.25) -> Image.Image:
|
||||
"""
|
||||
Compute a spectrogram image from a spectrogram magnitude array.
|
||||
|
||||
This is the inverse of spectrogram_from_image, except for discretization error from
|
||||
quantizing to uint8.
|
||||
|
||||
Args:
|
||||
spectrogram: (channels, frequency, time)
|
||||
power: A power curve to apply to the spectrogram to preserve contrast
|
||||
|
||||
Returns:
|
||||
image: (frequency, time, channels)
|
||||
"""
|
||||
# Rescale to 0-1
|
||||
max_value = np.max(spectrogram)
|
||||
data = spectrogram / max_value
|
||||
|
||||
# Apply the power curve
|
||||
data = np.power(data, power)
|
||||
|
||||
# Rescale to 0-255
|
||||
data = data * 255
|
||||
|
||||
# Invert
|
||||
data = 255 - data
|
||||
|
||||
# Convert to uint8
|
||||
data = data.astype(np.uint8)
|
||||
|
||||
# Munge channels into a PIL image
|
||||
if data.shape[0] == 1:
|
||||
# TODO(hayk): Do we want to write single channel to disk instead?
|
||||
image = Image.fromarray(data[0], mode="L").convert("RGB")
|
||||
elif data.shape[0] == 2:
|
||||
data = np.array([np.zeros_like(data[0]), data[0], data[1]]).transpose(1, 2, 0)
|
||||
image = Image.fromarray(data, mode="RGB")
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported number of channels: {data.shape[0]}")
|
||||
|
||||
# Flip Y
|
||||
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def spectrogram_from_image(
|
||||
image: Image.Image,
|
||||
power: float = 0.25,
|
||||
stereo: bool = False,
|
||||
max_value: float = 30e6,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute a spectrogram magnitude array from a spectrogram image.
|
||||
|
||||
This is the inverse of image_from_spectrogram, except for discretization error from
|
||||
quantizing to uint8.
|
||||
|
||||
Args:
|
||||
image: (frequency, time, channels)
|
||||
power: The power curve applied to the spectrogram
|
||||
stereo: Whether the spectrogram encodes stereo data
|
||||
        max_value: The max value of the original spectrogram. In practice it doesn't matter much.
|
||||
|
||||
Returns:
|
||||
spectrogram: (channels, frequency, time)
|
||||
"""
|
||||
# Convert to RGB if single channel
|
||||
if image.mode in ("P", "L"):
|
||||
image = image.convert("RGB")
|
||||
|
||||
# Flip Y
|
||||
image = image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)
|
||||
|
||||
# Munge channels into a numpy array of (channels, frequency, time)
|
||||
data = np.array(image).transpose(2, 0, 1)
|
||||
if stereo:
|
||||
# Take the G and B channels as done in image_from_spectrogram
|
||||
data = data[[1, 2], :, :]
|
||||
else:
|
||||
data = data[0:1, :, :]
|
||||
|
||||
# Convert to floats
|
||||
data = data.astype(np.float32)
|
||||
|
||||
# Invert
|
||||
data = 255 - data
|
||||
|
||||
# Rescale to 0-1
|
||||
data = data / 255
|
||||
|
||||
# Reverse the power curve
|
||||
data = np.power(data, 1 / power)
|
||||
|
||||
# Rescale to max value
|
||||
data = data * max_value
|
||||
|
||||
return data
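# Editorial sketch (not part of the original diff): the two functions above should
# approximately invert each other, up to uint8 quantization. The array shape, seed,
# and tolerances are arbitrary example values.
def _example_spectrogram_round_trip() -> None:
    rng = np.random.default_rng(seed=0)
    spectrogram = rng.uniform(0.0, 30e6, size=(1, 512, 512)).astype(np.float32)
    max_value = float(np.max(spectrogram))
    image = image_from_spectrogram(spectrogram, power=0.25)
    recovered = spectrogram_from_image(image, power=0.25, stereo=False, max_value=max_value)
    assert recovered.shape == spectrogram.shape
    # Agreement is only approximate because of the 8-bit quantization step.
    assert np.allclose(spectrogram, recovered, rtol=0.2, atol=0.02 * max_value)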
|
||||
|
||||
|
||||
def exif_from_image(pil_image: Image.Image) -> T.Dict[str, T.Any]:
|
||||
"""
|
||||
Get the EXIF data from a PIL image as a dict.
|
||||
"""
|
||||
exif = pil_image.getexif()
|
||||
|
||||
if exif is None or len(exif) == 0:
|
||||
return {}
|
||||
|
||||
return {SpectrogramParams.ExifTags(key).name: val for key, val in exif.items()}
|
|
@ -0,0 +1,48 @@
|
|||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def check_device(device: str, backup: str = "cpu") -> str:
|
||||
"""
|
||||
    Check that the device is valid and available. If not, fall back to the backup device.
|
||||
"""
|
||||
cuda_not_found = device.lower().startswith("cuda") and not torch.cuda.is_available()
|
||||
mps_not_found = device.lower().startswith("mps") and not torch.backends.mps.is_available()
|
||||
|
||||
if cuda_not_found or mps_not_found:
|
||||
warnings.warn(f"WARNING: {device} is not available, using {backup} instead.", stacklevel=3)
|
||||
return backup
|
||||
|
||||
return device
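# Editorial sketch (not part of the original diff): request CUDA but fall back to
# CPU when it is unavailable.
def _example_check_device() -> str:
    return check_device("cuda", backup="cpu")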
|
||||
|
||||
|
||||
def slerp(
|
||||
t: float, v0: torch.Tensor, v1: torch.Tensor, dot_threshold: float = 0.9995
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
    Helper function to spherically interpolate between two tensors or arrays v0 and v1.
|
||||
"""
|
||||
    inputs_are_torch = False
    if not isinstance(v0, np.ndarray):
|
||||
inputs_are_torch = True
|
||||
input_device = v0.device
|
||||
v0 = v0.cpu().numpy()
|
||||
v1 = v1.cpu().numpy()
|
||||
|
||||
dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
|
||||
if np.abs(dot) > dot_threshold:
|
||||
v2 = (1 - t) * v0 + t * v1
|
||||
else:
|
||||
theta_0 = np.arccos(dot)
|
||||
sin_theta_0 = np.sin(theta_0)
|
||||
theta_t = theta_0 * t
|
||||
sin_theta_t = np.sin(theta_t)
|
||||
s0 = np.sin(theta_0 - theta_t) / sin_theta_0
|
||||
s1 = sin_theta_t / sin_theta_0
|
||||
v2 = s0 * v0 + s1 * v1
|
||||
|
||||
if inputs_are_torch:
|
||||
v2 = torch.from_numpy(v2).to(input_device)
|
||||
|
||||
return v2
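# Editorial sketch (not part of the original diff): interpolate halfway between two
# random latent tensors. The shape and seed are arbitrary example values.
def _example_slerp_midpoint() -> torch.Tensor:
    generator = torch.Generator().manual_seed(42)
    v0 = torch.randn(1, 4, 64, 64, generator=generator)
    v1 = torch.randn(1, 4, 64, 64, generator=generator)
    midpoint = slerp(0.5, v0, v1)
    assert midpoint.shape == v0.shape
    return midpoint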
|
After Width: | Height: | Size: 132 KiB |
After Width: | Height: | Size: 128 KiB |
After Width: | Height: | Size: 7.1 KiB |
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 1.6 KiB |
After Width: | Height: | Size: 222 B |
After Width: | Height: | Size: 225 B |
After Width: | Height: | Size: 123 KiB |
After Width: | Height: | Size: 108 KiB |
After Width: | Height: | Size: 130 KiB |
|
@ -0,0 +1,99 @@
|
|||
import typing as T
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.cli import audio_to_image
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class AudioToImageTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli audio-to-image
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def default_params(cls) -> T.Dict:
|
||||
return dict(
|
||||
step_size_ms=10,
|
||||
num_frequencies=512,
|
||||
# TODO(hayk): Change these to [20, 20000] once a model is updated
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
stereo=False,
|
||||
device=cls.DEVICE,
|
||||
)
|
||||
|
||||
def test_audio_to_image(self) -> None:
|
||||
"""
|
||||
Test audio-to-image with default params.
|
||||
"""
|
||||
params = self.default_params()
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_stereo(self) -> None:
|
||||
"""
|
||||
Test audio-to-image with stereo=True.
|
||||
"""
|
||||
params = self.default_params()
|
||||
params["stereo"] = True
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def helper_test_with_params(self, params: T.Dict) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir("audio_to_image_")
|
||||
|
||||
if params["stereo"]:
|
||||
stem = f"{audio_path.stem}_stereo"
|
||||
else:
|
||||
stem = audio_path.stem
|
||||
|
||||
image_path = output_dir / f"{stem}.png"
|
||||
|
||||
audio_to_image(audio=str(audio_path), image=str(image_path), **params)
|
||||
|
||||
# Check that the image exists
|
||||
self.assertTrue(image_path.exists())
|
||||
|
||||
pil_image = Image.open(image_path)
|
||||
|
||||
# Check the image mode
|
||||
self.assertEqual(pil_image.mode, "RGB")
|
||||
|
||||
# Check the image dimensions
|
||||
duration_ms = 5678
|
||||
self.assertTrue(str(duration_ms) in audio_path.name)
|
||||
expected_image_width = round(duration_ms / params["step_size_ms"])
|
||||
self.assertEqual(pil_image.width, expected_image_width)
|
||||
self.assertEqual(pil_image.height, params["num_frequencies"])
|
||||
|
||||
# Get channels as numpy arrays
|
||||
channels = [np.array(pil_image.getchannel(i)) for i in range(len(pil_image.getbands()))]
|
||||
self.assertEqual(len(channels), 3)
|
||||
|
||||
if params["stereo"]:
|
||||
# Check that the first channel is zero
|
||||
self.assertTrue(np.all(channels[0] == 0))
|
||||
else:
|
||||
# Check that all channels are the same
|
||||
self.assertTrue(np.all(channels[0] == channels[1]))
|
||||
self.assertTrue(np.all(channels[0] == channels[2]))
|
||||
|
||||
# Check that the image has exif data
|
||||
exif = pil_image.getexif()
|
||||
self.assertIsNotNone(exif)
|
||||
params_from_exif = SpectrogramParams.from_exif(exif)
|
||||
expected_params = SpectrogramParams(
|
||||
stereo=params["stereo"],
|
||||
step_size_ms=params["step_size_ms"],
|
||||
num_frequencies=params["num_frequencies"],
|
||||
max_frequency=params["max_frequency"],
|
||||
)
|
||||
self.assertTrue(params_from_exif == expected_params)
|
|
@ -0,0 +1,71 @@
|
|||
from pathlib import Path
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.cli import image_to_audio
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class ImageToAudioTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli image-to-audio
|
||||
"""
|
||||
|
||||
def test_image_to_audio_mono(self) -> None:
|
||||
self.helper_image_to_audio(
|
||||
song_dir=self.TEST_DATA_PATH / "tired_traveler",
|
||||
clip_name="clip_2_start_103694_ms_duration_5678_ms",
|
||||
stereo=False,
|
||||
)
|
||||
|
||||
def test_image_to_audio_stereo(self) -> None:
|
||||
self.helper_image_to_audio(
|
||||
song_dir=self.TEST_DATA_PATH / "tired_traveler",
|
||||
clip_name="clip_2_start_103694_ms_duration_5678_ms",
|
||||
stereo=True,
|
||||
)
|
||||
|
||||
def helper_image_to_audio(self, song_dir: Path, clip_name: str, stereo: bool) -> None:
|
||||
if stereo:
|
||||
image_stem = clip_name + "_stereo"
|
||||
else:
|
||||
image_stem = clip_name
|
||||
|
||||
image_path = song_dir / "images" / f"{image_stem}.png"
|
||||
output_dir = self.get_tmp_dir("image_to_audio_")
|
||||
audio_path = output_dir / f"{image_path.stem}.wav"
|
||||
|
||||
image_to_audio(
|
||||
image=str(image_path),
|
||||
audio=str(audio_path),
|
||||
device=self.DEVICE,
|
||||
)
|
||||
|
||||
# Check that the audio exists
|
||||
self.assertTrue(audio_path.exists())
|
||||
|
||||
# Load the reconstructed audio and the original clip
|
||||
segment = pydub.AudioSegment.from_file(str(audio_path))
|
||||
expected_segment = pydub.AudioSegment.from_file(
|
||||
str(song_dir / "clips" / f"{clip_name}.wav")
|
||||
)
|
||||
|
||||
# Check sample rate
|
||||
self.assertEqual(segment.frame_rate, expected_segment.frame_rate)
|
||||
|
||||
# Check duration
|
||||
actual_duration_ms = round(segment.duration_seconds * 1000)
|
||||
expected_duration_ms = round(expected_segment.duration_seconds * 1000)
|
||||
self.assertTrue(abs(actual_duration_ms - expected_duration_ms) < 10)
|
||||
|
||||
# Check the number of channels
|
||||
self.assertEqual(expected_segment.channels, 2)
|
||||
if stereo:
|
||||
self.assertEqual(segment.channels, 2)
|
||||
else:
|
||||
self.assertEqual(segment.channels, 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestCase.main()
|
|
@ -0,0 +1,65 @@
|
|||
import numpy as np
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import image_util
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class ImageUtilTest(TestCase):
|
||||
"""
|
||||
Test riffusion.util.image_util
|
||||
"""
|
||||
|
||||
def test_spectrogram_to_image_round_trip(self) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
|
||||
# Load up the audio file
|
||||
segment = pydub.AudioSegment.from_file(audio_path)
|
||||
|
||||
# Convert to mono
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
# Compute a spectrogram with default params
|
||||
params = SpectrogramParams(sample_rate=segment.frame_rate)
|
||||
converter = SpectrogramConverter(params=params, device=self.DEVICE)
|
||||
spectrogram = converter.spectrogram_from_audio(segment)
|
||||
|
||||
# Compute the image from the spectrogram
|
||||
image = image_util.image_from_spectrogram(
|
||||
spectrogram=spectrogram,
|
||||
power=params.power_for_image,
|
||||
)
|
||||
|
||||
# Save the max value
|
||||
max_value = np.max(spectrogram)
|
||||
|
||||
# Compute the spectrogram from the image
|
||||
spectrogram_reversed = image_util.spectrogram_from_image(
|
||||
image=image,
|
||||
max_value=max_value,
|
||||
power=params.power_for_image,
|
||||
stereo=params.stereo,
|
||||
)
|
||||
|
||||
# Check the shapes
|
||||
self.assertEqual(spectrogram.shape, spectrogram_reversed.shape)
|
||||
|
||||
# Check the max values
|
||||
self.assertEqual(np.max(spectrogram), np.max(spectrogram_reversed))
|
||||
|
||||
# Check the median values
|
||||
self.assertTrue(
|
||||
np.allclose(np.median(spectrogram), np.median(spectrogram_reversed), rtol=0.05)
|
||||
)
|
||||
|
||||
# Make sure all values are somewhat similar, but allow for discretization error
|
||||
# TODO(hayk): Investigate error more closely
|
||||
self.assertTrue(np.allclose(spectrogram, spectrogram_reversed, rtol=0.15))
|
|
@ -0,0 +1,24 @@
|
|||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class LinterTest(TestCase):
|
||||
"""
|
||||
Test that ruff, black, and mypy run cleanly.
|
||||
"""
|
||||
|
||||
HOME = Path(__file__).parent.parent
|
||||
|
||||
def test_ruff(self) -> None:
|
||||
code = subprocess.check_call(["ruff", str(self.HOME)])
|
||||
self.assertEqual(code, 0)
|
||||
|
||||
def test_black(self) -> None:
|
||||
code = subprocess.check_call(["black", "--check", str(self.HOME)])
|
||||
self.assertEqual(code, 0)
|
||||
|
||||
def test_mypy(self) -> None:
|
||||
code = subprocess.check_call(["mypy", str(self.HOME)])
|
||||
self.assertEqual(code, 0)
|
|
@ -0,0 +1,32 @@
|
|||
import contextlib
|
||||
import io
|
||||
|
||||
from riffusion.cli import print_exif
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class PrintExifTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli print-exif
|
||||
"""
|
||||
|
||||
def test_print_exif(self) -> None:
|
||||
"""
|
||||
Test print-exif.
|
||||
"""
|
||||
image_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "images"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.png"
|
||||
)
|
||||
|
||||
# Redirect stdout
|
||||
stdout = io.StringIO()
|
||||
with contextlib.redirect_stdout(stdout):
|
||||
print_exif(image=str(image_path))
|
||||
|
||||
# Check that a couple of values are printed
|
||||
self.assertTrue("NUM_FREQUENCIES = 512" in stdout.getvalue())
|
||||
self.assertTrue("SAMPLE_RATE = 44100" in stdout.getvalue())
|
|
@ -0,0 +1,88 @@
|
|||
import shutil
import typing as T
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.cli import sample_clips
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SampleClipsTest(TestCase):
|
||||
"""
|
||||
Test riffusion.cli sample-clips
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def default_params() -> T.Dict:
|
||||
return dict(
|
||||
num_clips=3,
|
||||
duration_ms=5678,
|
||||
mono=False,
|
||||
extension="wav",
|
||||
seed=42,
|
||||
)
|
||||
|
||||
def test_sample_clips(self) -> None:
|
||||
"""
|
||||
Test sample-clips with default params.
|
||||
"""
|
||||
params = self.default_params()
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_mono(self) -> None:
|
||||
"""
|
||||
Test sample-clips with mono=True.
|
||||
"""
|
||||
params = self.default_params()
|
||||
params["mono"] = True
|
||||
params["num_clips"] = 1
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def test_mp3(self) -> None:
|
||||
"""
|
||||
Test sample-clips with extension=mp3.
|
||||
"""
|
||||
        if shutil.which("ffmpeg") is None:
|
||||
self.skipTest("skipping, ffmpeg not found")
|
||||
|
||||
params = self.default_params()
|
||||
params["extension"] = "mp3"
|
||||
params["num_clips"] = 1
|
||||
self.helper_test_with_params(params)
|
||||
|
||||
def helper_test_with_params(self, params: T.Dict) -> None:
|
||||
"""
|
||||
Test sample-clips with the given params.
|
||||
"""
|
||||
audio_path = self.TEST_DATA_PATH / "tired_traveler" / "tired_traveler.mp3"
|
||||
output_dir = self.get_tmp_dir("sample_clips_")
|
||||
|
||||
sample_clips(
|
||||
audio=str(audio_path),
|
||||
output_dir=str(output_dir),
|
||||
**params,
|
||||
)
|
||||
|
||||
# For each file in output dir
|
||||
counter = 0
|
||||
for clip_path in output_dir.iterdir():
|
||||
# Check that it has the right extension
|
||||
self.assertEqual(clip_path.suffix, f".{params['extension']}")
|
||||
|
||||
# Check that it has the right duration
|
||||
segment = pydub.AudioSegment.from_file(clip_path)
|
||||
self.assertEqual(round(segment.duration_seconds * 1000), params["duration_ms"])
|
||||
|
||||
# Check that it has the right number of channels
|
||||
if params["mono"]:
|
||||
self.assertEqual(segment.channels, 1)
|
||||
else:
|
||||
self.assertEqual(segment.channels, 2)
|
||||
|
||||
counter += 1
|
||||
|
||||
self.assertEqual(counter, params["num_clips"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
TestCase.main()
|
|
@ -0,0 +1,86 @@
|
|||
import dataclasses
|
||||
import typing as T
|
||||
|
||||
import pydub
|
||||
|
||||
from riffusion.spectrogram_converter import SpectrogramConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import fft_util
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SpectrogramConverterTest(TestCase):
|
||||
"""
|
||||
Test going from audio to spectrogram to audio, without converting to
|
||||
an image, to check quality loss of the reconstruction.
|
||||
|
||||
This test allows comparing multiple sets of spectrogram params by listening to output audio
|
||||
and by plotting their FFTs.
|
||||
"""
|
||||
|
||||
# TODO(hayk): Do an ablation of Griffin Lim and how much loss that introduces.
|
||||
|
||||
def test_round_trip(self) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir(prefix="spectrogram_round_trip_test_")
|
||||
|
||||
# Load up the audio file
|
||||
segment = pydub.AudioSegment.from_file(audio_path)
|
||||
|
||||
# Convert to mono if desired
|
||||
use_stereo = False
|
||||
if use_stereo:
|
||||
assert segment.channels == 2
|
||||
else:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
# Define named sets of parameters
|
||||
param_sets: T.Dict[str, SpectrogramParams] = {}
|
||||
|
||||
param_sets["default"] = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=use_stereo,
|
||||
step_size_ms=10,
|
||||
min_frequency=20,
|
||||
max_frequency=20000,
|
||||
num_frequencies=512,
|
||||
)
|
||||
|
||||
if self.DEBUG:
|
||||
param_sets["freq_0_to_10k"] = dataclasses.replace(
|
||||
param_sets["default"],
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
segments: T.Dict[str, pydub.AudioSegment] = {
|
||||
"original": segment,
|
||||
}
|
||||
for name, params in param_sets.items():
|
||||
converter = SpectrogramConverter(params=params, device=self.DEVICE)
|
||||
spectrogram = converter.spectrogram_from_audio(segment)
|
||||
segments[name] = converter.audio_from_spectrogram(spectrogram, apply_filters=True)
|
||||
|
||||
# Save segments to disk
|
||||
for name, segment in segments.items():
|
||||
audio_out = output_dir / f"{name}.wav"
|
||||
segment.export(audio_out, format="wav")
|
||||
print(f"Saved {audio_out}")
|
||||
|
||||
# Check params
|
||||
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
|
||||
self.assertEqual(segments["original"].channels, segments["default"].channels)
|
||||
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
|
||||
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
|
||||
|
||||
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
|
||||
|
||||
# If debugging, load up a browser tab plotting the FFTs
|
||||
if self.DEBUG:
|
||||
fft_util.plot_ffts(segments)
|
|
@ -0,0 +1,97 @@
|
|||
import dataclasses
|
||||
import typing as T
|
||||
|
||||
import pydub
|
||||
from PIL import Image
|
||||
|
||||
from riffusion.spectrogram_image_converter import SpectrogramImageConverter
|
||||
from riffusion.spectrogram_params import SpectrogramParams
|
||||
from riffusion.util import fft_util
|
||||
|
||||
from .test_case import TestCase
|
||||
|
||||
|
||||
class SpectrogramImageConverterTest(TestCase):
|
||||
"""
|
||||
Test going from audio to spectrogram images to audio, testing the quality loss of the
|
||||
end-to-end pipeline.
|
||||
|
||||
This test allows comparing multiple sets of spectrogram params by listening to output audio
|
||||
and by plotting their FFTs.
|
||||
|
||||
See spectrogram_converter_test.py for a similar test that does not convert to images.
|
||||
"""
|
||||
|
||||
def test_round_trip(self) -> None:
|
||||
audio_path = (
|
||||
self.TEST_DATA_PATH
|
||||
/ "tired_traveler"
|
||||
/ "clips"
|
||||
/ "clip_2_start_103694_ms_duration_5678_ms.wav"
|
||||
)
|
||||
output_dir = self.get_tmp_dir(prefix="spectrogram_image_round_trip_test_")
|
||||
|
||||
# Load up the audio file
|
||||
segment = pydub.AudioSegment.from_file(audio_path)
|
||||
|
||||
# Convert to mono if desired
|
||||
use_stereo = False
|
||||
if use_stereo:
|
||||
assert segment.channels == 2
|
||||
else:
|
||||
segment = segment.set_channels(1)
|
||||
|
||||
# Define named sets of parameters
|
||||
param_sets: T.Dict[str, SpectrogramParams] = {}
|
||||
|
||||
param_sets["default"] = SpectrogramParams(
|
||||
sample_rate=segment.frame_rate,
|
||||
stereo=use_stereo,
|
||||
step_size_ms=10,
|
||||
min_frequency=20,
|
||||
max_frequency=20000,
|
||||
num_frequencies=512,
|
||||
)
|
||||
|
||||
if self.DEBUG:
|
||||
param_sets["freq_0_to_10k"] = dataclasses.replace(
|
||||
param_sets["default"],
|
||||
min_frequency=0,
|
||||
max_frequency=10000,
|
||||
)
|
||||
|
||||
segments: T.Dict[str, pydub.AudioSegment] = {
|
||||
"original": segment,
|
||||
}
|
||||
images: T.Dict[str, Image.Image] = {}
|
||||
for name, params in param_sets.items():
|
||||
converter = SpectrogramImageConverter(params=params, device=self.DEVICE)
|
||||
images[name] = converter.spectrogram_image_from_audio(segment)
|
||||
segments[name] = converter.audio_from_spectrogram_image(
|
||||
image=images[name],
|
||||
apply_filters=True,
|
||||
)
|
||||
|
||||
# Save images to disk
|
||||
for name, image in images.items():
|
||||
image_out = output_dir / f"{name}.png"
|
||||
image.save(image_out, exif=image.getexif(), format="PNG")
|
||||
print(f"Saved {image_out}")
|
||||
|
||||
# Save segments to disk
|
||||
for name, segment in segments.items():
|
||||
audio_out = output_dir / f"{name}.wav"
|
||||
segment.export(audio_out, format="wav")
|
||||
print(f"Saved {audio_out}")
|
||||
|
||||
# Check params
|
||||
self.assertEqual(segments["default"].channels, 2 if use_stereo else 1)
|
||||
self.assertEqual(segments["original"].channels, segments["default"].channels)
|
||||
self.assertEqual(segments["original"].frame_rate, segments["default"].frame_rate)
|
||||
self.assertEqual(segments["original"].sample_width, segments["default"].sample_width)
|
||||
|
||||
# TODO(hayk): Test something more rigorous about the quality of the reconstruction.
|
||||
|
||||
# If debugging, load up a browser tab plotting the FFTs
|
||||
if self.DEBUG:
|
||||
fft_util.plot_ffts(segments)
|
|
@ -0,0 +1,48 @@
|
|||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import typing as T
|
||||
import unittest
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class TestCase(unittest.TestCase):
|
||||
"""
|
||||
Base class for tests.
|
||||
"""
|
||||
|
||||
# Where checked-in test data is stored
|
||||
TEST_DATA_PATH = Path(__file__).resolve().parent / "test_data"
|
||||
|
||||
# Whether to run tests in debug mode (e.g. don't clean up temporary directories, show plots)
|
||||
DEBUG = bool(os.environ.get("RIFFUSION_TEST_DEBUG"))
|
||||
|
||||
# Which torch device to use for tests
|
||||
DEVICE = os.environ.get("RIFFUSION_TEST_DEVICE", "cuda")
|
||||
|
||||
@staticmethod
|
||||
def main(*args: T.Any, **kwargs: T.Any) -> None:
|
||||
"""
|
||||
Run the tests.
|
||||
"""
|
||||
unittest.main(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
warnings.filterwarnings("ignore", category=ResourceWarning)
|
||||
|
||||
def get_tmp_dir(self, prefix: str) -> Path:
|
||||
"""
|
||||
Create a temporary directory.
|
||||
"""
|
||||
tmp_dir = tempfile.mkdtemp(prefix=prefix)
|
||||
|
||||
# Clean up the temporary directory if not debugging
|
||||
if not self.DEBUG:
|
||||
self.addCleanup(lambda: shutil.rmtree(tmp_dir, ignore_errors=True))
|
||||
|
||||
dir_path = Path(tmp_dir)
|
||||
assert dir_path.is_dir()
|
||||
|
||||
return dir_path
|
|
@ -0,0 +1,7 @@
|
|||
# Test Data
|
||||
|
||||
### tired_traveler
|
||||
|
||||
* Song: Tired traveler on the way to home
|
||||
* Artist: Andrew Codeman
|
||||
* Source: https://freemusicarchive.org/
|
After Width: | Height: | Size: 258 KiB |
After Width: | Height: | Size: 382 KiB |