From b64c70c9e7a2a416117cb6b317cb85e5d679717a Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Tue, 25 Jun 2024 18:21:29 +0800 Subject: [PATCH] Cpu tgi (#1936) * add CPU tgi support Signed-off-by: Wang, Yi A * ipex distributed ops support Signed-off-by: Wang, Yi A --------- Signed-off-by: Wang, Yi A Co-authored-by: Funtowicz Morgan --- Dockerfile_intel | 87 ++++++++++++++++++- .../layers/attention/__init__.py | 4 +- .../layers/attention/xpu.py | 5 +- .../layers/layernorm.py | 24 ++--- .../text_generation_server/layers/rotary.py | 6 +- .../layers/tensor_parallel.py | 31 +++++-- .../custom_modeling/flash_dbrx_modeling.py | 4 +- .../custom_modeling/flash_mixtral_modeling.py | 4 +- .../models/flash_causal_lm.py | 49 +++++++---- .../models/flash_gpt2.py | 5 +- .../models/flash_llama.py | 5 +- .../models/flash_mistral.py | 5 +- .../models/flash_neox.py | 5 +- .../text_generation_server/models/flash_rw.py | 5 +- .../models/flash_santacoder.py | 5 +- server/text_generation_server/utils/dist.py | 35 ++++---- .../utils/import_utils.py | 19 ++-- 17 files changed, 221 insertions(+), 77 deletions(-) diff --git a/Dockerfile_intel b/Dockerfile_intel index f09614d4..a41fbc1e 100644 --- a/Dockerfile_intel +++ b/Dockerfile_intel @@ -1,3 +1,5 @@ +ARG PLATFORM=xpu + FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef WORKDIR /usr/src @@ -37,7 +39,8 @@ RUN cargo build --profile release-opt # Text Generation Inference base image for Intel -FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base + +FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu USER root # libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it @@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \ WORKDIR /usr/src RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl -RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed # Install server COPY proto proto @@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca # Install launcher COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher -# Final image -FROM base +# Text Generation Inference base image for Intel-cpu +FROM ubuntu:22.04 as cpu + +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \ + curl \ + ca-certificates \ + make \ + g++ \ + git \ + wget \ + cmake + +ENV HUGGINGFACE_HUB_CACHE=/data \ + HF_HUB_ENABLE_HF_TRANSFER=1 \ + PORT=80 + +ARG MAMBA_VERSION=23.1.0-1 +ARG PYTHON_VERSION='3.10.10' +# Automatically set by buildx +ARG TARGETPLATFORM +ENV PATH /opt/conda/bin:$PATH + +# TGI seem to require libssl.so.1.1 instead of libssl.so.3 so we can't use ubuntu 22.04. Ubuntu 20.04 has python==3.8, and TGI requires python>=3.9, hence the need for miniconda. +# Install mamba +# translating Docker's TARGETPLATFORM into mamba arches +RUN case ${TARGETPLATFORM} in \ + "linux/arm64") MAMBA_ARCH=aarch64 ;; \ + *) MAMBA_ARCH=x86_64 ;; \ + esac && \ + curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh" +RUN chmod +x ~/mambaforge.sh && \ + bash ~/mambaforge.sh -b -p /opt/conda && \ + rm ~/mambaforge.sh + +RUN conda install -c conda-forge gperftools mkl + +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl +RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl + +WORKDIR /usr/src + +RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a + +RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131 + +RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install + +RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install . + +ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so +ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch +ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch +ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric +ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib +ENV KMP_BLOCKTIME=1 +ENV KMP_TPAUSE=0 +ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist +ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist +ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist + +# Install server +COPY proto proto +COPY server server +COPY server/Makefile server/Makefile +RUN cd server && \ + make gen-server && \ + pip install -r requirements_intel.txt && \ + pip install ".[accelerate, peft, outlines]" --no-cache-dir + +# Install benchmarker +COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark +# Install router +COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router +# Install launcher +COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher + +FROM ${PLATFORM} as final ENTRYPOINT ["text-generation-launcher"] CMD ["--json-output"] diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py index e6cb4edf..a53c8e3b 100644 --- a/server/text_generation_server/layers/attention/__init__.py +++ b/server/text_generation_server/layers/attention/__init__.py @@ -1,4 +1,4 @@ -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL import os if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": @@ -7,7 +7,7 @@ if SYSTEM == "cuda": from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING elif SYSTEM == "rocm": from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING -elif SYSTEM == "xpu": +elif IPEX_AVAIL: from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING else: raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/xpu.py index 8b6cb87b..bfab0119 100644 --- a/server/text_generation_server/layers/attention/xpu.py +++ b/server/text_generation_server/layers/attention/xpu.py @@ -1,5 +1,6 @@ import intel_extension_for_pytorch as ipex import torch +from text_generation_server.models.flash_causal_lm import BLOCK_SIZE SUPPORTS_WINDOWING = False @@ -56,8 +57,6 @@ def paged_attention( input_lengths: torch.Tensor, max_s: int, ): - query = query.contiguous() - block_size = value_cache.shape[3] return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention( out, query, @@ -67,7 +66,7 @@ def paged_attention( softmax_scale, block_tables, input_lengths, - block_size, + BLOCK_SIZE, max_s, None, ) diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py index c4aa6c7d..93e83dfa 100644 --- a/server/text_generation_server/layers/layernorm.py +++ b/server/text_generation_server/layers/layernorm.py @@ -3,6 +3,7 @@ from torch import nn from accelerate import init_empty_weights from text_generation_server.utils.import_utils import ( SYSTEM, + IPEX_AVAIL, ) @@ -82,18 +83,20 @@ elif SYSTEM == "rocm": return super().forward(hidden_states), residual -elif SYSTEM == "xpu": +elif IPEX_AVAIL: import intel_extension_for_pytorch as ipex class FastLayerNorm(nn.LayerNorm): def forward(self, hidden_states, residual=None): - res_out = hidden_states out = ipex.llm.functional.add_layer_norm( - residual, hidden_states, self.weight, self.bias, self.eps, True + residual, + hidden_states, + self.weight, + self.bias, + self.eps, + residual is not None, ) - if residual is not None: - res_out = residual - return out, res_out + return out, residual if residual is not None else hidden_states class FastRMSNorm(nn.Module): @@ -109,19 +112,16 @@ class FastRMSNorm(nn.Module): return cls(weight, eps) def forward(self, hidden_states, residual=None): - if SYSTEM == "xpu": - residual_out = hidden_states + if IPEX_AVAIL: out = ipex.llm.functional.add_rms_norm( residual, hidden_states, self.weight, None, self.variance_epsilon, - True, + residual is not None, ) - if residual is not None: - residual_out = residual - return out, residual_out + return out, residual if residual is not None else hidden_states elif hidden_states.shape[-1] > 8192: if residual is not None: hidden_states += residual diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py index c2f12189..1892cf69 100644 --- a/server/text_generation_server/layers/rotary.py +++ b/server/text_generation_server/layers/rotary.py @@ -2,14 +2,14 @@ import os import torch from torch import nn -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL if SYSTEM == "cuda": from flash_attn.layers.rotary import RotaryEmbedding import rotary_emb elif SYSTEM == "rocm": from vllm._C import ops -elif SYSTEM == "xpu": +elif IPEX_AVAIL: import intel_extension_for_pytorch as ipex @@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module): # Inplace operation, updating query and key. ops.rotary_embedding(query, key, head_size, cos, sin, True) - elif SYSTEM == "xpu": + elif IPEX_AVAIL: ipex.llm.functional.rotary_embedding( query, key, sin, cos, query.size(-1), True ) diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py index 6005f737..510dc2c6 100644 --- a/server/text_generation_server/layers/tensor_parallel.py +++ b/server/text_generation_server/layers/tensor_parallel.py @@ -3,6 +3,10 @@ from torch.nn import functional as F from typing import Iterable, List from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.exl2 import Exl2Weight +from text_generation_server.utils.import_utils import IPEX_AVAIL + +if IPEX_AVAIL: + import intel_extension_for_pytorch as ipex class LayerConcat(torch.nn.Module): @@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer): local_out = gather_input.T torch.mm(input, self.linear.weight.T, out=local_out) - - torch.distributed.all_gather_into_tensor( - world_out, gather_input, group=self.process_group - ) + if IPEX_AVAIL: + ipex.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) + else: + torch.distributed.all_gather_into_tensor( + world_out, gather_input, group=self.process_group + ) if input.shape[0] == 1: return world_out @@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer): world_output = [ torch.empty_like(output) for _ in range(self.process_group.size()) ] - torch.distributed.all_gather(world_output, output, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_gather(world_output, output, group=self.process_group) + else: + torch.distributed.all_gather(world_output, output, group=self.process_group) world_output = torch.cat(world_output, dim=-1) return world_output @@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer): def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: out = super().forward(input) if self.process_group.size() > 1 and reduce: - torch.distributed.all_reduce(out, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module): ) out = torch.nn.functional.embedding(input, self.weight) if self.reduce and self.process_group.size() > 1: - torch.distributed.all_reduce(out, group=self.process_group) + if IPEX_AVAIL: + ipex.distributed.all_reduce(out, group=self.process_group) + else: + torch.distributed.all_reduce(out, group=self.process_group) return out diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index 94cf6452..56292250 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -20,9 +20,9 @@ from torch import nn from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig from typing import Optional, List, Tuple, Any -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import IPEX_AVAIL -if SYSTEM != "xpu": +if not IPEX_AVAIL: from vllm.model_executor.layers.fused_moe import fused_moe from text_generation_server.layers.attention import ( diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 3900bf73..6ea95411 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -24,9 +24,9 @@ import torch.distributed import numpy as np from torch import nn -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import IPEX_AVAIL -if SYSTEM != "xpu": +if not IPEX_AVAIL: from vllm.model_executor.layers.fused_moe import fused_moe from transformers.activations import ACT2FN from transformers.configuration_utils import PretrainedConfig diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index d16d3710..633b066b 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from text_generation_server.utils.chunks import concat_text_chunks -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.models import Model from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.dist import RANK @@ -773,21 +773,38 @@ class FlashCausalLM(Model): else: x = BLOCK_SIZE // element_size - self.kv_cache = [ - ( - torch.empty( - (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), - dtype=dtype, - device=device, - ), - torch.empty( - (num_blocks, num_heads, head_size, BLOCK_SIZE), - dtype=dtype, - device=device, - ), - ) - for _ in range(num_layers) - ] + if IPEX_AVAIL and SYSTEM == "cpu": + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, BLOCK_SIZE, head_size), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] + else: + self.kv_cache = [ + ( + torch.empty( + (num_blocks, num_heads, head_size // x, BLOCK_SIZE, x), + dtype=dtype, + device=device, + ), + torch.empty( + (num_blocks, num_heads, head_size, BLOCK_SIZE), + dtype=dtype, + device=device, + ), + ) + for _ in range(num_layers) + ] def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int): input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device) diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py index ef129e92..43f374e5 100644 --- a/server/text_generation_server/models/flash_gpt2.py +++ b/server/text_generation_server/models/flash_gpt2.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -37,6 +37,9 @@ class FlashGPT2(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashGPT2 is only available on GPU") diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py index e27f0da2..e023c4e0 100644 --- a/server/text_generation_server/models/flash_llama.py +++ b/server/text_generation_server/models/flash_llama.py @@ -17,7 +17,7 @@ from text_generation_server.utils import ( tracer = trace.get_tracer(__name__) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL class FlashLlama(FlashCausalLM): @@ -37,6 +37,9 @@ class FlashLlama(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashLlama is only available on GPU") diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py index 0fdda6d2..0c570487 100644 --- a/server/text_generation_server/models/flash_mistral.py +++ b/server/text_generation_server/models/flash_mistral.py @@ -16,7 +16,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -41,6 +41,9 @@ class BaseFlashMistral(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashMistral is only available on GPU") diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py index d3871c2f..69d47e57 100644 --- a/server/text_generation_server/models/flash_neox.py +++ b/server/text_generation_server/models/flash_neox.py @@ -14,7 +14,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -36,6 +36,9 @@ class FlashNeoXSharded(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashNeoX is only available on GPU") diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py index 187f26a8..e087dcf1 100644 --- a/server/text_generation_server/models/flash_rw.py +++ b/server/text_generation_server/models/flash_rw.py @@ -15,7 +15,7 @@ from text_generation_server.utils import ( weight_files, Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -37,6 +37,9 @@ class FlashRWSharded(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashRW is only available on GPU") diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py index a8d84fca..9626af60 100644 --- a/server/text_generation_server/models/flash_santacoder.py +++ b/server/text_generation_server/models/flash_santacoder.py @@ -18,7 +18,7 @@ from text_generation_server.utils import ( Weights, ) -from text_generation_server.utils.import_utils import SYSTEM +from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL tracer = trace.get_tracer(__name__) @@ -40,6 +40,9 @@ class FlashSantacoderSharded(FlashCausalLM): elif SYSTEM == "xpu": device = torch.device(f"xpu:{rank}") dtype = torch.float16 if dtype is None else dtype + elif IPEX_AVAIL: + device = torch.device("cpu") + dtype = torch.bfloat16 if dtype is None else dtype else: raise NotImplementedError("FlashSantacoderSharded is only available on GPU") diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py index 3625e6f2..7d387563 100644 --- a/server/text_generation_server/utils/dist.py +++ b/server/text_generation_server/utils/dist.py @@ -3,6 +3,7 @@ import torch from datetime import timedelta from loguru import logger +from text_generation_server.utils.import_utils import IPEX_AVAIL # Tensor Parallelism settings RANK = int(os.getenv("RANK", "0")) @@ -57,14 +58,7 @@ def initialize_torch_distributed(): options.is_high_priority_stream = True options._timeout = timedelta(seconds=60) else: - try: - import oneccl_bindings_for_pytorch - - backend = "ccl" - if os.getenv("CCL_WORKER_COUNT", None) is None: - os.environ["CCL_WORKER_COUNT"] = str(1) - except ImportError: - backend = "gloo" + backend = "gloo" options = None if WORLD_SIZE == 1: @@ -75,13 +69,24 @@ def initialize_torch_distributed(): if not torch.distributed.is_initialized(): # Call the init process. - torch.distributed.init_process_group( - backend=backend, - world_size=WORLD_SIZE, - rank=RANK, - timeout=timedelta(seconds=60), - pg_options=options, - ) + if IPEX_AVAIL: + import intel_extension_for_pytorch as ipex + + ipex.distributed.init_process_group( + backend="ccl", + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) + else: + torch.distributed.init_process_group( + backend=backend, + world_size=WORLD_SIZE, + rank=RANK, + timeout=timedelta(seconds=60), + pg_options=options, + ) else: logger.warning("torch.distributed is already initialized.") diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py index c3929392..a244417a 100644 --- a/server/text_generation_server/utils/import_utils.py +++ b/server/text_generation_server/utils/import_utils.py @@ -3,13 +3,12 @@ from loguru import logger import subprocess -def is_xpu_available(): +def is_ipex_available(): try: import intel_extension_for_pytorch except ImportError: return False - - return hasattr(torch, "xpu") and torch.xpu.is_available() + return True def get_cuda_free_memory(device, memory_fraction): @@ -29,6 +28,16 @@ def get_xpu_free_memory(device, memory_fraction): return free_memory +def get_cpu_free_memory(device, memory_fraction): + import psutil + from text_generation_server.utils.dist import WORLD_SIZE + + mem = psutil.virtual_memory() + free_memory = int(mem.available * 0.95 / WORLD_SIZE) + return free_memory + + +IPEX_AVAIL = is_ipex_available() SYSTEM = None if torch.version.hip is not None: SYSTEM = "rocm" @@ -40,7 +49,7 @@ elif torch.version.cuda is not None and torch.cuda.is_available(): empty_cache = torch.cuda.empty_cache synchronize = torch.cuda.synchronize get_free_memory = get_cuda_free_memory -elif is_xpu_available(): +elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available(): SYSTEM = "xpu" empty_cache = torch.xpu.empty_cache synchronize = torch.xpu.synchronize @@ -53,5 +62,5 @@ else: empty_cache = noop synchronize = noop - get_free_memory = noop + get_free_memory = get_cpu_free_memory logger.info(f"Detected system {SYSTEM}")