Add Ascend NPU support for TGI
This commit is contained in:
parent 7dbaf9e901
commit aba7072fd5
launcher/src/env_runtime.rs
@@ -7,14 +7,17 @@ pub(crate) struct Env {
     git_sha: &'static str,
     docker_label: &'static str,
     nvidia_env: String,
+    npu_env: String,
 }
 
 impl Env {
     pub fn new() -> Self {
         let nvidia_env = nvidia_smi();
+        let npu_env = npu_smi();
 
         Self {
             nvidia_env: nvidia_env.unwrap_or("N/A".to_string()),
+            npu_env: npu_env.unwrap_or("N/A".to_string()),
             cargo_target: env!("VERGEN_CARGO_TARGET_TRIPLE"),
             cargo_version: env!("VERGEN_RUSTC_SEMVER"),
             git_sha: option_env!("VERGEN_GIT_SHA").unwrap_or("N/A"),
@@ -31,7 +34,8 @@ impl fmt::Display for Env {
         writeln!(f, "Cargo version: {}", self.cargo_version)?;
         writeln!(f, "Commit sha: {}", self.git_sha)?;
         writeln!(f, "Docker label: {}", self.docker_label)?;
-        write!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        writeln!(f, "nvidia-smi:\n{}", self.nvidia_env)?;
+        write!(f, "npu-smi:\n{}", self.npu_env)?;
 
         Ok(())
     }
@@ -43,3 +47,10 @@ fn nvidia_smi() -> Option<String> {
     let output = nvidia_smi.replace('\n', "\n   ");
     Some(output.trim().to_string())
 }
+
+fn npu_smi() -> Option<String> {
+    // "info" must be a separate argument; Command::new("npu-smi info")
+    // would look for an executable literally named "npu-smi info".
+    let output = Command::new("npu-smi").arg("info").output().ok()?;
+    let npu_smi = String::from_utf8(output.stdout).ok()?;
+    let output = npu_smi.replace('\n', "\n   ");
+    Some(output.trim().to_string())
+}
server/text_generation_server/cache.py
@@ -3,6 +3,7 @@ import torch
 from typing import Dict, Optional, TypeVar
 
 from text_generation_server.models.types import Batch
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 B = TypeVar("B", bound=Batch)
@@ -24,6 +25,8 @@ class Cache:
             del batch
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif IS_NPU_SYSTEM:
+            torch.npu.empty_cache()
 
     def clear(self):
         keys = list(self.cache.keys())
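The CUDA-then-NPU fallback above is the same pattern the commit applies again in interceptor.py below. A minimal sketch of the dispatch in isolation, assuming torch_npu is installed on Ascend hosts; the empty_device_cache helper is illustrative, not part of the commit:

```python
import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def empty_device_cache():
    # Release cached allocator blocks on whichever accelerator is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif IS_NPU_SYSTEM:
        # torch_npu mirrors the torch.cuda API under the torch.npu namespace.
        torch.npu.empty_cache()
```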
server/text_generation_server/interceptor.py
@@ -6,6 +6,7 @@ from grpc_status import rpc_status
 from grpc_interceptor.server import AsyncServerInterceptor
 from loguru import logger
 from typing import Callable, Any
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
@@ -25,6 +26,8 @@ class ExceptionInterceptor(AsyncServerInterceptor):
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif IS_NPU_SYSTEM:
+                torch.npu.empty_cache()
 
             await context.abort_with_status(
                 rpc_status.to_status(
server/text_generation_server/models/globals.py
@@ -1,6 +1,6 @@
 import torch
 import os
 
-MEM_POOL = torch.cuda.graph_pool_handle()
+MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None
 # This is overridden by the cli
 ENABLE_CUDA_GRAPHS = os.getenv("ENABLE_CUDA_GRAPHS", "false").lower() in {"1", "true"}
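The guard matters because torch.cuda.graph_pool_handle() fails on hosts without CUDA, which would break this module at import time on NPU-only or CPU-only machines. A minimal illustration of how the guarded pool would be consumed; the capture snippet is an assumption for illustration, not code from this commit:

```python
import torch

# Shared allocator pool for CUDA graphs; None on non-CUDA hosts so that
# importing this module no longer crashes on Ascend NPU or CPU machines.
MEM_POOL = torch.cuda.graph_pool_handle() if torch.cuda.is_available() else None

if MEM_POOL is not None:
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=MEM_POOL):
        pass  # capture CUDA work here; skipped entirely on NPU/CPU hosts
```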
server/text_generation_server/utils/dist.py
@@ -3,12 +3,14 @@ import torch
 
 from datetime import timedelta
 from loguru import logger
+from text_generation_server.utils.import_utils import IS_NPU_SYSTEM
 
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
 WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
 
 # CUDA memory fraction
+# TODO: Do we need to rename CUDA_MEMORY_FRACTION to DEVICE_MEMORY_FRACTION?
 MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))
@@ -56,6 +58,15 @@ def initialize_torch_distributed():
         options = ProcessGroupNCCL.Options()
         options.is_high_priority_stream = True
         options._timeout = timedelta(seconds=60)
+    elif IS_NPU_SYSTEM:
+        assert WORLD_SIZE <= torch.npu.device_count(), "Each process is one npu"
+        device = RANK % torch.npu.device_count()
+        torch.npu.set_device(device)
+        torch.npu.set_per_process_memory_fraction(MEMORY_FRACTION, device)
+        backend = "hccl"
+        options = ProcessGroupNCCL.Options()
+        options.is_high_priority_stream = True
+        options._timeout = timedelta(seconds=60)
     else:
         backend = "gloo"
         options = None
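Downstream of this branch, initialize_torch_distributed hands the chosen backend to torch.distributed.init_process_group; "hccl" is the Huawei Collective Communication Library backend that torch_npu registers as the Ascend counterpart of NCCL. A condensed sketch of that final step, assuming the rest of the function matches TGI's (only the hunk above is verbatim from the commit):

```python
import os
from datetime import timedelta

import torch


def init_distributed_sketch(backend: str, options=None):
    # Condensed, hypothetical version of the tail of
    # initialize_torch_distributed(): one process group per worker.
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(
            backend=backend,  # "nccl" on CUDA, "hccl" on Ascend NPU, else "gloo"
            world_size=world_size,
            rank=rank,
            timeout=timedelta(seconds=60),
            pg_options=options,
        )
    return torch.distributed.group.WORLD, rank, world_size
```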
server/text_generation_server/utils/import_utils.py
@@ -1,4 +1,15 @@
 import torch
 
+
+def is_npu_available():
+    try:
+        import torch_npu  # noqa: F401
+    except ImportError:
+        return False
+
+    return hasattr(torch, "npu") and torch.npu.is_available()
+
+
 IS_ROCM_SYSTEM = torch.version.hip is not None
 IS_CUDA_SYSTEM = torch.version.cuda is not None
+IS_NPU_SYSTEM = is_npu_available()
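Importing torch_npu is what registers the torch.npu namespace and the "npu" device type with PyTorch, so the probe doubles as initialization. A short usage sketch of the new flag; the get_device helper is hypothetical, not part of the commit:

```python
import torch

from text_generation_server.utils.import_utils import IS_NPU_SYSTEM


def get_device() -> torch.device:
    # Hypothetical helper: prefer CUDA, then Ascend NPU, else fall back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if IS_NPU_SYSTEM:
        # is_npu_available() has already imported torch_npu, which makes
        # the "npu" device type usable here.
        return torch.device("npu")
    return torch.device("cpu")
```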