add ascend npu support for TGI

This commit is contained in:
statelesshz 2024-04-14 16:11:10 +08:00
parent 7dbaf9e901
commit aba7072fd5
6 changed files with 41 additions and 2 deletions

View File

@ -7,14 +7,17 @@ pub(crate) struct Env {
git_sha: &'static str,
docker_label: &'static str,
nvidia_env: String,
npu_env: String,
}
impl Env {
pub fn new() -> Self {
let nvidia_env = nvidia_smi();
let npu_env = npu_smi();
Self {
nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
npu_env: npu_env.unwrap_or("N/A".to_string()),
cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
cargo_version: env!("VERGEN_RUSTC_SEMVER"),
git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@ -31,7 +34,8 @@ impl fmt::Display for Env {
writeln!(f, "Cargo version: {}", self.cargo_version)?;
writeln!(f, "Commit sha: {}", self.git_sha)?;
writeln!(f, "Docker label: {}", self.docker_label)?;
write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
write!(f, "npu-smi:\n{}", self.npu_env)?;
Ok(())
}
@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
let output = nvidia_smi.replace('\n', "\n ");
Some(output.trim().to_string())
}
fn npu_smi() -> Option<String> {
let output = Command::new("npu-smi info").output().ok()?;
let npu_smi = String::from_utf8(output.stdout).ok()?;
let output = npu_smi.replace('\n', "\n ");
Some(output.trim().to_string())
}

View File

@ -3,6 +3,7 @@ import torch
from typing import Dict, Optional, TypeVar
from text_generation_server.models.types import Batch
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
B = TypeVar("B", bound=Batch)
@ -24,6 +25,8 @@ class Cache:
del batch
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif IS_NPU_SYSTEM:
torch.npu.empty_cache()
def clear(self):
keys = list(self.cache.keys())

View File

@ -6,6 +6,7 @@ from grpc_status import rpc_status
from grpc_interceptor.server import AsyncServerInterceptor
from loguru import logger
from typing import Callable, Any
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
class ExceptionInterceptor(AsyncServerInterceptor):
@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
if torch.cuda.is_available():
torch.cuda.empty_cache()
elif IS_NPU_SYSTEM:
torch.npu.empty_cache()
await context.abort_with_status(
rpc_status.to_status(

View File

@ -1,6 +1,6 @@
import torch
import os
MEM_POOL = torch.cuda.graph_pool_handle()
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
# This is overridden by the cli
ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}

View File

@ -3,12 +3,14 @@ import torch
from datetime import timedelta
from loguru import logger
from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
# Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
# CUDA memory fraction
# TODO: Do we need to rename CUDA_MEMORY_FRACTION to DEVICE_MEMORY_FRACTION?
MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
@ -56,6 +58,15 @@ def initialize_torch_distributed():
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
elif IS_NPU_SYSTEM:
assert WORLD_SIZE <= torch.npu.device_count(), "Each process is one npu"
device = RANK % torch.npu.device_count()
torch.npu.set_device(device)
torch.npu.set_per_process_memory_fraction(MEMORY_FRACTION, device)
backend = "hccl"
options = ProcessGroupNCCL.Options()
options.is_high_priority_stream = True
options._timeout = timedelta(seconds=60)
else:
backend = "gloo"
options = None

View File

@ -1,4 +1,15 @@
import torch
def is_npu_available():
try:
import torch_npu # noqa: F401
except ImportError:
return False
return hasattr(torch, "npu") and torch.npu.is_available()
IS_ROCM_SYSTEM = torch.version.hip is not None
IS_CUDA_SYSTEM = torch.version.cuda is not None
IS_NPU_SYSTEM = is_npu_available()