diff --git a/.gitignore b/.gitignore index b6e4761..2b82e86 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# VSCode +.vscode + # Distribution / packaging .Python build/ @@ -27,6 +30,9 @@ share/python-wheels/ *.egg MANIFEST +# OSX cruft +.DS_Store + # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. diff --git a/CITATION b/CITATION new file mode 100644 index 0000000..9f4ae3a --- /dev/null +++ b/CITATION @@ -0,0 +1,6 @@ +@article{Forsgren_Martiros_2022, + author = {Forsgren, Seth* and Martiros, Hayk*}, + title = {{Riffusion - Stable diffusion for real-time music generation}}, + url = {https://riffusion.com/about}, + year = {2022} +} diff --git a/README.md b/README.md index 2fd615a..30864df 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ Tested with Python 3.9 and diffusers 0.9.0. To run this model, you need a GPU with CUDA. To run it in real time, it needs to be able to run stable diffusion with approximately 50 steps in under five seconds. -You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html). +You need to make sure you have torch and torchaudio installed with CUDA support. See the [install guide](https://pytorch.org/get-started/locally/) or [stable wheels](https://download.pytorch.org/whl/torch_stable.html). ``` conda create --name riffusion-inference python=3.9 @@ -32,14 +32,16 @@ python -m pip install -r requirements.txt If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12). +You can open and save WAV files with pure Python. For opening and saving non-WAV files – like MP3 – you'll need ffmpeg or libav. 
+ Guides: * [CUDA help](https://github.com/riffusion/riffusion/issues/3) * [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/) -## Run +## Run the model server Start the Flask server: ``` -python -m riffusion.server --port 3013 --host 127.0.0.1 +python -m riffusion.server --host 127.0.0.1 --port 3013 ``` You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format. @@ -77,6 +79,52 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe } ``` +Use the `--device` argument to specify the torch device to use. + +`cuda` is recommended. + +`cpu` works but is quite slow. + +`mps` is supported for inference, but some operations fall back to CPU. You may need to set +`PYTORCH_ENABLE_MPS_FALLBACK=1`. In addition, results on `mps` are not deterministic. + +## Test +Tests live in the `test/` directory and are implemented with `unittest`. + +To run all tests: +``` +python -m unittest test/*_test.py +``` + +To run a single test module: +``` +python -m unittest test.audio_to_image_test +``` + +To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`: +``` +RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test +``` + +To run a single test case: +``` +python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo +``` + +To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with +`cpu`, `cuda`, and `mps` backends. + +## Development +Install additional packages for dev with `pip install -r dev_requirements.txt`. + +* Linter: `ruff` +* Formatter: `black` +* Type checker: `mypy` + +These are configured in `pyproject.toml`. + +The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR. 
+ ## Citation If you build on this work, please cite it as follows: diff --git a/dev_requirements.txt b/dev_requirements.txt index b10ce3e..a37435b 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -1,6 +1,7 @@ black ipdb -isort mypy -pylint +ruff +types-Flask-Cors +types-Pillow types-requests diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c6fb776 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,87 @@ +[tool.black] +line-length = 100 + +[tool.ruff] +line-length = 100 + +# Which rules to run +select = [ + # Pyflakes + "F", + # Pycodestyle + "E", + "W", + # isort + # "I001" +] + +ignore = [] + +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", +] +per-file-ignores = {} + +# Allow unused variables when underscore-prefixed. +dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +# Assume Python 3.10. +target-version = "py310" + +[tool.ruff.mccabe] +# Unlike Flake8, default to a complexity level of 10. 
+max-complexity = 10 + +[tool.mypy] +python_version = "3.10" + +[[tool.mypy.overrides]] +module = "argh.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "diffusers.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "plotly.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "pydub.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "scipy.fft.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "scipy.io.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "torchaudio.*" +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "transformers.*" +ignore_missing_imports = true diff --git a/requirements.txt b/requirements.txt index 57edf91..84aa270 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,9 +6,11 @@ flask flask_cors numpy pillow +plotly pydub scipy soundfile +streamlit torch torchaudio transformers