Update python project configuration files

* Adds a pyproject.toml
* Update requirements and dev requirements
* Add a CITATION file
* Add details to the README

Topic: clean_rewrite
This commit is contained in:
Hayk Martiros 2022-12-26 17:26:46 -08:00
parent cbf473216b
commit 52bec9575b
6 changed files with 155 additions and 5 deletions

6
.gitignore vendored
View File

@ -6,6 +6,9 @@ __pycache__/
# C extensions
*.so
# VSCode
.vscode
# Distribution / packaging
.Python
build/
@ -27,6 +30,9 @@ share/python-wheels/
*.egg
MANIFEST
# OSX cruft
.DS_Store
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.

6
CITATION Normal file
View File

@ -0,0 +1,6 @@
@article{Forsgren_Martiros_2022,
author = {Forsgren, Seth* and Martiros, Hayk*},
title = {{Riffusion - Stable diffusion for real-time music generation}},
url = {https://riffusion.com/about},
year = {2022}
}

View File

@ -32,14 +32,16 @@ python -m pip install -r requirements.txt
If torchaudio has no audio backend, see [this issue](https://github.com/riffusion/riffusion/issues/12).
You can open and save WAV files with pure python. For opening and saving non-wav files like mp3 you'll need ffmpeg or libav.
Guides:
* [CUDA help](https://github.com/riffusion/riffusion/issues/3)
* [Windows Simple Instructions](https://www.reddit.com/r/riffusion/comments/zrubc9/installation_guide_for_riffusion_app_inference/)
## Run
## Run the model server
Start the Flask server:
```
python -m riffusion.server --port 3013 --host 127.0.0.1
python -m riffusion.server --host 127.0.0.1 --port 3013
```
You can specify `--checkpoint` with your own directory or huggingface ID in diffusers format.
@ -77,6 +79,52 @@ Example output (see [InferenceOutput](https://github.com/hmartiro/riffusion-infe
}
```
Use the `--device` argument to specify the torch device to use.
`cuda` is recommended.
`cpu` works but is quite slow.
`mps` is supported for inference, but some operations fall back to CPU. You may need to set
PYTORCH_ENABLE_MPS_FALLBACK=1. In addition, it is not deterministic.
## Test
Tests live in the `test/` directory and are implemented with `unittest`.
To run all tests:
```
python -m unittest test/*_test.py
```
To run a single test:
```
python -m unittest test.audio_to_image_test
```
To preserve temporary outputs for debugging, set `RIFFUSION_TEST_DEBUG`:
```
RIFFUSION_TEST_DEBUG=1 python -m unittest test.audio_to_image_test
```
To run a single test case:
```
python -m unittest test.audio_to_image_test -k AudioToImageTest.test_stereo
```
To run tests using a specific torch device, set `RIFFUSION_TEST_DEVICE`. Tests should pass with
`cpu`, `cuda`, and `mps` backends.
## Development
Install additional packages for dev with `pip install -r dev_requirements.txt`.
* Linter: `ruff`
* Formatter: `black`
* Type checker: `mypy`
These are configured in `pyproject.toml`.
The results of `mypy .`, `black .`, and `ruff .` *must* be clean to accept a PR.
## Citation
If you build on this work, please cite it as follows:

View File

@ -1,6 +1,7 @@
black
ipdb
isort
mypy
pylint
ruff
types-Flask-Cors
types-Pillow
types-requests

87
pyproject.toml Normal file
View File

@ -0,0 +1,87 @@
[tool.black]
line-length = 100
[tool.ruff]
line-length = 100
# Which rules to run
select = [
# Pyflakes
"F",
# Pycodestyle
"E",
"W",
# isort
# "I001"
]
ignore = []
# Exclude a variety of commonly ignored directories.
exclude = [
".bzr",
".direnv",
".eggs",
".git",
".hg",
".mypy_cache",
".nox",
".pants.d",
".ruff_cache",
".svn",
".tox",
".venv",
"__pypackages__",
"_build",
"buck-out",
"build",
"dist",
"node_modules",
"venv",
]
per-file-ignores = {}
# Allow unused variables when underscore-prefixed.
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
# Assume Python 3.10.
target-version = "py310"
[tool.ruff.mccabe]
# Unlike Flake8, default to a complexity level of 10.
max-complexity = 10
[tool.mypy]
python_version = "3.10"
[[tool.mypy.overrides]]
module = "argh.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "diffusers.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "plotly.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "pydub.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.fft.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "scipy.io.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "torchaudio.*"
ignore_missing_imports = true
[[tool.mypy.overrides]]
module = "transformers.*"
ignore_missing_imports = true

View File

@ -6,9 +6,11 @@ flask
flask_cors
numpy
pillow
plotly
pydub
scipy
soundfile
streamlit
torch
torchaudio
transformers