From a407c9f0147c779865c940cbf62c7019dbc1f7b4 Mon Sep 17 00:00:00 2001 From: DaniAndTheWeb <57776841+DaniAndTheWeb@users.noreply.github.com> Date: Fri, 13 Jan 2023 19:22:23 +0100 Subject: [PATCH] Automatic torch install for amd on linux This commit allows the launch script to automatically download rocm's torch version for AMD GPUs using an external GPU detection script. It also prints the operative system and GPU in use. --- launch.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/launch.py b/launch.py index bcbb792ca..668548f13 100644 --- a/launch.py +++ b/launch.py @@ -7,6 +7,7 @@ import shlex import platform import argparse import json +import detection dir_repos = "repositories" dir_extensions = "extensions" @@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git") index_url = os.environ.get('INDEX_URL', "") stored_commit_hash = None +# Get the GPU vendor and the operating system +gpu = detection.check_gpu() +if os.name == "posix": + os_name = platform.uname().system +else: + os_name = os.name def commit_hash(): global stored_commit_hash @@ -173,7 +180,11 @@ def run_extensions_installers(settings_file): def prepare_environment(): - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + if gpu == "AMD" and os_name !="nt": + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2") + else: + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113") + requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") commandline_args = os.environ.get('COMMANDLINE_ARGS', "") @@ -295,6 +306,8 @@ def tests(test_dir): def start(): + print(f"Operating System: {os_name}") + print(f"GPU: {gpu}") print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}") import webui if '--nowebui' in sys.argv: