From 6ae04672b660bd3682798be44e6fd92a7ccd3d7a Mon Sep 17 00:00:00 2001
From: Hugo Larcher
Date: Thu, 3 Oct 2024 11:34:38 +0200
Subject: [PATCH] fix: Add variable results file location

---
Notes (ignored by git am, not part of the commit message):
  * The CI workflow now passes --sha and --results-file explicitly, and
    benchmarks.py no longer reads GITHUB_SHA from the environment.
  * When --results-file is omitted, the script writes to the relative path
    '{sha}.parquet' instead of the S3 location.
  * NOTE(review): this copy was rebuilt from a version whose line breaks had
    been collapsed. The leading whitespace of context lines is reconstructed;
    confirm against the target tree, or apply with
    `git apply --ignore-whitespace`.

 .github/workflows/load_test.yaml |  2 +-
 load_tests/benchmarks.py         | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml
index 0f9407d8..4c212e08 100644
--- a/.github/workflows/load_test.yaml
+++ b/.github/workflows/load_test.yaml
@@ -44,7 +44,7 @@ jobs:
           export PATH="$HOME/.local/bin:$PATH"
           cd load_tests
           poetry install
-          poetry run python benchmarks.py
+          poetry run python benchmarks.py --sha ${{ github.sha }} --results-file "s3://text-generation-inference-ci/benchmarks/ci/${{ github.sha }}.parquet"
         shell: bash
         env:
           HF_TOKEN: ${{ secrets.HF_TOKEN_BENCHMARK }}
diff --git a/load_tests/benchmarks.py b/load_tests/benchmarks.py
index 64e5d728..84635590 100644
--- a/load_tests/benchmarks.py
+++ b/load_tests/benchmarks.py
@@ -1,3 +1,4 @@
+import argparse
 import datetime
 import json
 import os
@@ -162,7 +163,7 @@ def build_df(model: str, data_files: dict[str, str]) -> pd.DataFrame:
     return df
 
 
-def main():
+def main(sha, results_file):
     results_dir = 'results'
     # get absolute path
     results_dir = os.path.join(os.path.dirname(__file__), results_dir)
@@ -172,7 +173,6 @@ def main():
         # ('meta-llama/Llama-3.1-70B-Instruct', 4),
         # ('mistralai/Mixtral-8x7B-Instruct-v0.1', 2),
     ]
-    sha = os.environ.get('GITHUB_SHA')
     success = True
     for model in models:
         tgi_runner = TGIDockerRunner(model[0])
@@ -225,8 +225,18 @@ def main():
         df = pd.concat([df, build_df(directory.split('/')[-1], data_files)])
     df['device'] = get_gpu_name()
     df['error_rate'] = df['failed_requests'] / (df['failed_requests'] + df['successful_requests']) * 100.0
-    df.to_parquet(f's3://text-generation-inference-ci/benchmarks/ci/{sha}.parquet')
+    df.to_parquet(results_file)
 
 
 if __name__ == "__main__":
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--sha", help="SHA of the commit to add to the results", required=True)
+    parser.add_argument("--results-file",
+                        help="The file where to store the results, can be a local file or a s3 path")
+    args = parser.parse_args()
+    if args.results_file is None:
+        results_file = f'{args.sha}.parquet'
+    else:
+        results_file = args.results_file
+
+    main(args.sha, results_file)