added test workflow and fixed failing test (#237)

* added test workflow and fixed failing test

* 4 decimal places
This commit is contained in:
Kashif Rasul 2022-08-24 07:46:53 -04:00 committed by GitHub
parent 102cabeb23
commit 47893164ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 2 deletions

38
.github/workflows/test.yml vendored Normal file
View File

@ -0,0 +1,38 @@
name: Run Tests
on:
push:
branches: [ $default-branch ]
pull_request:
branches: [ $default-branch ]
jobs:
build:
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.9", "3.10"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
cache: 'pip' # caching pip dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install flake8 pytest
python -m pip install -e .
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest

View File

@ -567,8 +567,8 @@ class PNDMSchedulerTest(SchedulerCommonTest):
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 199.1169) < 1e-2
assert abs(result_mean.item() - 0.2593) < 1e-3
assert abs(result_sum.item() - 428.8788) < 1e-2
assert abs(result_mean.item() - 0.5584) < 1e-3
class ScoreSdeVeSchedulerTest(unittest.TestCase):