Add NPU Support
This commit is contained in:
parent
cf2772fab0
commit
ec124607f4
|
@ -3,7 +3,7 @@ import contextlib
|
|||
from functools import lru_cache
|
||||
|
||||
import torch
|
||||
from modules import errors, shared
|
||||
from modules import errors, shared, npu_specific
|
||||
|
||||
if sys.platform == "darwin":
|
||||
from modules import mac_specific
|
||||
|
@ -40,6 +40,9 @@ def get_optimal_device_name():
|
|||
if has_xpu():
|
||||
return xpu_specific.get_xpu_device_string()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
return npu_specific.get_npu_device_string()
|
||||
|
||||
return "cpu"
|
||||
|
||||
|
||||
|
@ -67,6 +70,9 @@ def torch_gc():
|
|||
if has_xpu():
|
||||
xpu_specific.torch_xpu_gc()
|
||||
|
||||
if npu_specific.has_npu:
|
||||
npu_specific.torch_npu_gc()
|
||||
|
||||
|
||||
def enable_tf32():
|
||||
if torch.cuda.is_available():
|
||||
|
@ -164,4 +170,3 @@ def first_time_calculation():
|
|||
x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
|
||||
conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
|
||||
conv2d(x)
|
||||
|
||||
|
|
|
@ -143,13 +143,17 @@ def initialize_rest(*, reload_script_modules=False):
|
|||
its optimization may be None because the list of optimizaers has neet been filled
|
||||
by that time, so we apply optimization again.
|
||||
"""
|
||||
from modules import devices
|
||||
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
||||
if devices.npu_specific.has_npu:
|
||||
import torch
|
||||
torch.npu.set_device(0)
|
||||
|
||||
shared.sd_model # noqa: B018
|
||||
|
||||
if sd_hijack.current_optimizer is None:
|
||||
sd_hijack.apply_optimizations()
|
||||
|
||||
from modules import devices
|
||||
devices.first_time_calculation()
|
||||
if not shared.cmd_opts.skip_load_model_at_start:
|
||||
Thread(target=load_model).start()
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
import importlib
|
||||
import torch
|
||||
|
||||
from modules import shared
|
||||
|
||||
|
||||
def check_for_npu():
|
||||
if importlib.util.find_spec("torch_npu") is None:
|
||||
return False
|
||||
import torch_npu
|
||||
torch_npu.npu.set_device(0)
|
||||
|
||||
try:
|
||||
# Will raise a RuntimeError if no NPU is found
|
||||
_ = torch.npu.device_count()
|
||||
return torch.npu.is_available()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
|
||||
def get_npu_device_string():
|
||||
if shared.cmd_opts.device_id is not None:
|
||||
return f"npu:{shared.cmd_opts.device_id}"
|
||||
return "npu:0"
|
||||
|
||||
|
||||
def torch_npu_gc():
|
||||
# Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
|
||||
torch.npu.set_device(0)
|
||||
with torch.npu.device(get_npu_device_string()):
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
has_npu = check_for_npu()
|
|
@ -151,6 +151,10 @@ class EmbeddingDatabase:
|
|||
return embedding
|
||||
|
||||
def get_expected_shape(self):
|
||||
# workaround
|
||||
if devices.npu_specific.has_npu:
|
||||
import torch
|
||||
torch.npu.set_device(0)
|
||||
vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
|
||||
return vec.shape[1]
|
||||
|
||||
|
|
|
@ -5,6 +5,8 @@ accelerate
|
|||
basicsr
|
||||
blendmodes
|
||||
clean-fid
|
||||
cloudpickle
|
||||
decorator
|
||||
einops
|
||||
fastapi>=0.90.1
|
||||
gfpgan
|
||||
|
@ -26,9 +28,11 @@ resize-right
|
|||
|
||||
safetensors
|
||||
scikit-image>=0.19
|
||||
synr==0.5.0
|
||||
timm
|
||||
tomesd
|
||||
torch
|
||||
torchdiffeq
|
||||
torchsde
|
||||
tornado
|
||||
transformers==4.30.2
|
||||
|
|
|
@ -4,6 +4,8 @@ accelerate==0.21.0
|
|||
basicsr==1.4.2
|
||||
blendmodes==2022
|
||||
clean-fid==0.1.35
|
||||
cloudpickle==3.0.0
|
||||
decorator==5.1.1
|
||||
einops==0.4.1
|
||||
fastapi==0.94.0
|
||||
gfpgan==1.3.8
|
||||
|
@ -23,10 +25,12 @@ realesrgan==0.3.0
|
|||
resize-right==0.0.2
|
||||
safetensors==0.3.1
|
||||
scikit-image==0.21.0
|
||||
synr==0.5.0
|
||||
timm==0.9.2
|
||||
tomesd==0.1.3
|
||||
torch
|
||||
torchdiffeq==0.2.3
|
||||
torchsde==0.2.6
|
||||
tornado==6.4
|
||||
transformers==4.30.2
|
||||
httpx==0.24.1
|
||||
|
|
4
webui.sh
4
webui.sh
|
@ -159,6 +159,10 @@ then
|
|||
if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]]
|
||||
then
|
||||
export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2"
|
||||
elif echo "$gpu_info" | grep -q "Huawei" && [[ -z "${TORCH_COMMAND}" ]]
|
||||
then
|
||||
export TORCH_COMMAND="pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu; pip install torch_npu"
|
||||
|
||||
fi
|
||||
fi
|
||||
|
||||
|
|
Loading…
Reference in New Issue