Cpu tgi (#1936)
* add CPU tgi support Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * ipex distributed ops support Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Funtowicz Morgan <mfuntowicz@users.noreply.github.com>
This commit is contained in:
parent
b69f078041
commit
b64c70c9e7
|
@ -1,3 +1,5 @@
|
||||||
|
ARG PLATFORM=xpu
|
||||||
|
|
||||||
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
FROM lukemathwalker/cargo-chef:latest-rust-1.79 AS chef
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
@ -37,7 +39,8 @@ RUN cargo build --profile release-opt
|
||||||
|
|
||||||
|
|
||||||
# Text Generation Inference base image for Intel
|
# Text Generation Inference base image for Intel
|
||||||
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as base
|
|
||||||
|
FROM intel/intel-extension-for-pytorch:2.1.30-xpu as xpu
|
||||||
|
|
||||||
USER root
|
USER root
|
||||||
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
# libssl.so.1.1 is not installed on Ubuntu 22.04 by default, install it
|
||||||
|
@ -59,7 +62,7 @@ ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
|
|
||||||
WORKDIR /usr/src
|
WORKDIR /usr/src
|
||||||
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
|
RUN wget https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/xpu/torch-2.1.0.post1%2Bcxx11.abi-cp310-cp310-linux_x86_64.whl && pip install torch-2.1.0.post1+cxx11.abi-cp310-cp310-linux_x86_64.whl
|
||||||
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b group_rope origin/dev/gqa_rope
|
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout -b distributed origin/dev/distributed
|
||||||
|
|
||||||
# Install server
|
# Install server
|
||||||
COPY proto proto
|
COPY proto proto
|
||||||
|
@ -89,8 +92,84 @@ COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/loca
|
||||||
# Install launcher
|
# Install launcher
|
||||||
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
# Final image
|
|
||||||
FROM base
|
|
||||||
|
|
||||||
|
# Text Generation Inference base image for Intel-cpu
|
||||||
|
FROM ubuntu:22.04 as cpu
|
||||||
|
|
||||||
|
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
|
||||||
|
curl \
|
||||||
|
ca-certificates \
|
||||||
|
make \
|
||||||
|
g++ \
|
||||||
|
git \
|
||||||
|
wget \
|
||||||
|
cmake
|
||||||
|
|
||||||
|
ENV HUGGINGFACE_HUB_CACHE=/data \
|
||||||
|
HF_HUB_ENABLE_HF_TRANSFER=1 \
|
||||||
|
PORT=80
|
||||||
|
|
||||||
|
ARG MAMBA_VERSION=23.1.0-1
|
||||||
|
ARG PYTHON_VERSION='3.10.10'
|
||||||
|
# Automatically set by buildx
|
||||||
|
ARG TARGETPLATFORM
|
||||||
|
ENV PATH /opt/conda/bin:$PATH
|
||||||
|
|
||||||
|
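# TGI seems to require libssl.so.1.1 instead of libssl.so.3, so we can't use Ubuntu 22.04. Ubuntu 20.04 ships python==3.8, and TGI requires python>=3.9, hence the need for miniconda. NOTE(review): this stage is declared `FROM ubuntu:22.04` — confirm whether this comment is stale.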
|
||||||
|
# Install mamba
|
||||||
|
# translating Docker's TARGETPLATFORM into mamba arches
|
||||||
|
RUN case ${TARGETPLATFORM} in \
|
||||||
|
"linux/arm64") MAMBA_ARCH=aarch64 ;; \
|
||||||
|
*) MAMBA_ARCH=x86_64 ;; \
|
||||||
|
esac && \
|
||||||
|
curl -fsSL -v -o ~/mambaforge.sh -O "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
|
||||||
|
RUN chmod +x ~/mambaforge.sh && \
|
||||||
|
bash ~/mambaforge.sh -b -p /opt/conda && \
|
||||||
|
rm ~/mambaforge.sh
|
||||||
|
|
||||||
|
RUN conda install -c conda-forge gperftools mkl
|
||||||
|
|
||||||
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torch-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||||
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchvision-0.19.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||||
|
RUN pip install https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.4.0.dev20240612%2Bcpu-cp310-cp310-linux_x86_64.whl
|
||||||
|
|
||||||
|
WORKDIR /usr/src
|
||||||
|
|
||||||
|
RUN git clone https://github.com/intel/intel-extension-for-pytorch && cd intel-extension-for-pytorch && git checkout eda7a7c42df6f9a64e0de9c2b69304ee02f2c32a
|
||||||
|
|
||||||
|
RUN git clone https://github.com/intel/torch-ccl.git && cd torch-ccl && git checkout ccl_torch_dev_0131
|
||||||
|
|
||||||
|
RUN cd intel-extension-for-pytorch && git submodule sync && git submodule update --init --recursive && python setup.py install
|
||||||
|
|
||||||
|
RUN cd torch-ccl && git submodule sync && git submodule update --init --recursive && pip install .
|
||||||
|
|
||||||
|
ENV LD_PRELOAD=/opt/conda/lib/libtcmalloc.so:/opt/conda/lib/libiomp5.so
|
||||||
|
ENV CCL_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
||||||
|
ENV I_MPI_ROOT=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch
|
||||||
|
ENV FI_PROVIDER_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib/prov:/usr/lib64/libfabric
|
||||||
|
ENV LD_LIBRARY_PATH=/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/opt/mpi/libfabric/lib:/opt/conda/lib/python3.10/site-packages/oneccl_bindings_for_pytorch/lib
|
||||||
|
ENV KMP_BLOCKTIME=1
|
||||||
|
ENV KMP_TPAUSE=0
|
||||||
|
ENV KMP_FORKJOIN_BARRIER_PATTERN=dist,dist
|
||||||
|
ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
|
||||||
|
ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist
|
||||||
|
|
||||||
|
# Install server
|
||||||
|
COPY proto proto
|
||||||
|
COPY server server
|
||||||
|
COPY server/Makefile server/Makefile
|
||||||
|
RUN cd server && \
|
||||||
|
make gen-server && \
|
||||||
|
pip install -r requirements_intel.txt && \
|
||||||
|
pip install ".[accelerate, peft, outlines]" --no-cache-dir
|
||||||
|
|
||||||
|
# Install benchmarker
|
||||||
|
COPY --from=builder /usr/src/target/release-opt/text-generation-benchmark /usr/local/bin/text-generation-benchmark
|
||||||
|
# Install router
|
||||||
|
COPY --from=builder /usr/src/target/release-opt/text-generation-router /usr/local/bin/text-generation-router
|
||||||
|
# Install launcher
|
||||||
|
COPY --from=builder /usr/src/target/release-opt/text-generation-launcher /usr/local/bin/text-generation-launcher
|
||||||
|
|
||||||
|
FROM ${PLATFORM} as final
|
||||||
ENTRYPOINT ["text-generation-launcher"]
|
ENTRYPOINT ["text-generation-launcher"]
|
||||||
CMD ["--json-output"]
|
CMD ["--json-output"]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
import os
|
import os
|
||||||
|
|
||||||
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
|
||||||
|
@ -7,7 +7,7 @@ if SYSTEM == "cuda":
|
||||||
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
elif SYSTEM == "xpu":
|
elif IPEX_AVAIL:
|
||||||
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
|
||||||
else:
|
else:
|
||||||
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
import torch
|
import torch
|
||||||
|
from text_generation_server.models.flash_causal_lm import BLOCK_SIZE
|
||||||
|
|
||||||
SUPPORTS_WINDOWING = False
|
SUPPORTS_WINDOWING = False
|
||||||
|
|
||||||
|
@ -56,8 +57,6 @@ def paged_attention(
|
||||||
input_lengths: torch.Tensor,
|
input_lengths: torch.Tensor,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
):
|
):
|
||||||
query = query.contiguous()
|
|
||||||
block_size = value_cache.shape[3]
|
|
||||||
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
return ipex.llm.modules.PagedAttention.single_query_cached_kv_attention(
|
||||||
out,
|
out,
|
||||||
query,
|
query,
|
||||||
|
@ -67,7 +66,7 @@ def paged_attention(
|
||||||
softmax_scale,
|
softmax_scale,
|
||||||
block_tables,
|
block_tables,
|
||||||
input_lengths,
|
input_lengths,
|
||||||
block_size,
|
BLOCK_SIZE,
|
||||||
max_s,
|
max_s,
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,6 +3,7 @@ from torch import nn
|
||||||
from accelerate import init_empty_weights
|
from accelerate import init_empty_weights
|
||||||
from text_generation_server.utils.import_utils import (
|
from text_generation_server.utils.import_utils import (
|
||||||
SYSTEM,
|
SYSTEM,
|
||||||
|
IPEX_AVAIL,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,18 +83,20 @@ elif SYSTEM == "rocm":
|
||||||
|
|
||||||
return super().forward(hidden_states), residual
|
return super().forward(hidden_states), residual
|
||||||
|
|
||||||
elif SYSTEM == "xpu":
|
elif IPEX_AVAIL:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
class FastLayerNorm(nn.LayerNorm):
|
class FastLayerNorm(nn.LayerNorm):
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
res_out = hidden_states
|
|
||||||
out = ipex.llm.functional.add_layer_norm(
|
out = ipex.llm.functional.add_layer_norm(
|
||||||
residual, hidden_states, self.weight, self.bias, self.eps, True
|
residual,
|
||||||
|
hidden_states,
|
||||||
|
self.weight,
|
||||||
|
self.bias,
|
||||||
|
self.eps,
|
||||||
|
residual is not None,
|
||||||
)
|
)
|
||||||
if residual is not None:
|
return out, residual if residual is not None else hidden_states
|
||||||
res_out = residual
|
|
||||||
return out, res_out
|
|
||||||
|
|
||||||
|
|
||||||
class FastRMSNorm(nn.Module):
|
class FastRMSNorm(nn.Module):
|
||||||
|
@ -109,19 +112,16 @@ class FastRMSNorm(nn.Module):
|
||||||
return cls(weight, eps)
|
return cls(weight, eps)
|
||||||
|
|
||||||
def forward(self, hidden_states, residual=None):
|
def forward(self, hidden_states, residual=None):
|
||||||
if SYSTEM == "xpu":
|
if IPEX_AVAIL:
|
||||||
residual_out = hidden_states
|
|
||||||
out = ipex.llm.functional.add_rms_norm(
|
out = ipex.llm.functional.add_rms_norm(
|
||||||
residual,
|
residual,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
self.weight,
|
self.weight,
|
||||||
None,
|
None,
|
||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
True,
|
residual is not None,
|
||||||
)
|
)
|
||||||
if residual is not None:
|
return out, residual if residual is not None else hidden_states
|
||||||
residual_out = residual
|
|
||||||
return out, residual_out
|
|
||||||
elif hidden_states.shape[-1] > 8192:
|
elif hidden_states.shape[-1] > 8192:
|
||||||
if residual is not None:
|
if residual is not None:
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
|
|
|
@ -2,14 +2,14 @@ import os
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
if SYSTEM == "cuda":
|
if SYSTEM == "cuda":
|
||||||
from flash_attn.layers.rotary import RotaryEmbedding
|
from flash_attn.layers.rotary import RotaryEmbedding
|
||||||
import rotary_emb
|
import rotary_emb
|
||||||
elif SYSTEM == "rocm":
|
elif SYSTEM == "rocm":
|
||||||
from vllm._C import ops
|
from vllm._C import ops
|
||||||
elif SYSTEM == "xpu":
|
elif IPEX_AVAIL:
|
||||||
import intel_extension_for_pytorch as ipex
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
|
||||||
|
|
||||||
# Inplace operation, updating query and key.
|
# Inplace operation, updating query and key.
|
||||||
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
ops.rotary_embedding(query, key, head_size, cos, sin, True)
|
||||||
elif SYSTEM == "xpu":
|
elif IPEX_AVAIL:
|
||||||
ipex.llm.functional.rotary_embedding(
|
ipex.llm.functional.rotary_embedding(
|
||||||
query, key, sin, cos, query.size(-1), True
|
query, key, sin, cos, query.size(-1), True
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,6 +3,10 @@ from torch.nn import functional as F
|
||||||
from typing import Iterable, List
|
from typing import Iterable, List
|
||||||
from text_generation_server.layers.linear import get_linear, FastLinear
|
from text_generation_server.layers.linear import get_linear, FastLinear
|
||||||
from text_generation_server.layers.exl2 import Exl2Weight
|
from text_generation_server.layers.exl2 import Exl2Weight
|
||||||
|
from text_generation_server.utils.import_utils import IPEX_AVAIL
|
||||||
|
|
||||||
|
if IPEX_AVAIL:
|
||||||
|
import intel_extension_for_pytorch as ipex
|
||||||
|
|
||||||
|
|
||||||
class LayerConcat(torch.nn.Module):
|
class LayerConcat(torch.nn.Module):
|
||||||
|
@ -96,10 +100,14 @@ class TensorParallelHead(SuperLayer):
|
||||||
local_out = gather_input.T
|
local_out = gather_input.T
|
||||||
|
|
||||||
torch.mm(input, self.linear.weight.T, out=local_out)
|
torch.mm(input, self.linear.weight.T, out=local_out)
|
||||||
|
if IPEX_AVAIL:
|
||||||
torch.distributed.all_gather_into_tensor(
|
ipex.distributed.all_gather_into_tensor(
|
||||||
world_out, gather_input, group=self.process_group
|
world_out, gather_input, group=self.process_group
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_gather_into_tensor(
|
||||||
|
world_out, gather_input, group=self.process_group
|
||||||
|
)
|
||||||
|
|
||||||
if input.shape[0] == 1:
|
if input.shape[0] == 1:
|
||||||
return world_out
|
return world_out
|
||||||
|
@ -109,7 +117,10 @@ class TensorParallelHead(SuperLayer):
|
||||||
world_output = [
|
world_output = [
|
||||||
torch.empty_like(output) for _ in range(self.process_group.size())
|
torch.empty_like(output) for _ in range(self.process_group.size())
|
||||||
]
|
]
|
||||||
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
if IPEX_AVAIL:
|
||||||
|
ipex.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_gather(world_output, output, group=self.process_group)
|
||||||
world_output = torch.cat(world_output, dim=-1)
|
world_output = torch.cat(world_output, dim=-1)
|
||||||
return world_output
|
return world_output
|
||||||
|
|
||||||
|
@ -206,7 +217,10 @@ class TensorParallelRowLinear(SuperLayer):
|
||||||
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
|
||||||
out = super().forward(input)
|
out = super().forward(input)
|
||||||
if self.process_group.size() > 1 and reduce:
|
if self.process_group.size() > 1 and reduce:
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
if IPEX_AVAIL:
|
||||||
|
ipex.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
@ -243,5 +257,8 @@ class TensorParallelEmbedding(torch.nn.Module):
|
||||||
)
|
)
|
||||||
out = torch.nn.functional.embedding(input, self.weight)
|
out = torch.nn.functional.embedding(input, self.weight)
|
||||||
if self.reduce and self.process_group.size() > 1:
|
if self.reduce and self.process_group.size() > 1:
|
||||||
torch.distributed.all_reduce(out, group=self.process_group)
|
if IPEX_AVAIL:
|
||||||
|
ipex.distributed.all_reduce(out, group=self.process_group)
|
||||||
|
else:
|
||||||
|
torch.distributed.all_reduce(out, group=self.process_group)
|
||||||
return out
|
return out
|
||||||
|
|
|
@ -20,9 +20,9 @@ from torch import nn
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
from typing import Optional, List, Tuple, Any
|
from typing import Optional, List, Tuple, Any
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import IPEX_AVAIL
|
||||||
|
|
||||||
if SYSTEM != "xpu":
|
if not IPEX_AVAIL:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
|
|
||||||
from text_generation_server.layers.attention import (
|
from text_generation_server.layers.attention import (
|
||||||
|
|
|
@ -24,9 +24,9 @@ import torch.distributed
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import IPEX_AVAIL
|
||||||
|
|
||||||
if SYSTEM != "xpu":
|
if not IPEX_AVAIL:
|
||||||
from vllm.model_executor.layers.fused_moe import fused_moe
|
from vllm.model_executor.layers.fused_moe import fused_moe
|
||||||
from transformers.activations import ACT2FN
|
from transformers.activations import ACT2FN
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
|
|
@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
|
||||||
|
|
||||||
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
|
||||||
from text_generation_server.utils.chunks import concat_text_chunks
|
from text_generation_server.utils.chunks import concat_text_chunks
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
from text_generation_server.models import Model
|
from text_generation_server.models import Model
|
||||||
from text_generation_server.utils.tokens import batch_top_tokens
|
from text_generation_server.utils.tokens import batch_top_tokens
|
||||||
from text_generation_server.utils.dist import RANK
|
from text_generation_server.utils.dist import RANK
|
||||||
|
@ -773,21 +773,38 @@ class FlashCausalLM(Model):
|
||||||
else:
|
else:
|
||||||
x = BLOCK_SIZE // element_size
|
x = BLOCK_SIZE // element_size
|
||||||
|
|
||||||
self.kv_cache = [
|
if IPEX_AVAIL and SYSTEM == "cpu":
|
||||||
(
|
self.kv_cache = [
|
||||||
torch.empty(
|
(
|
||||||
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
|
torch.empty(
|
||||||
dtype=dtype,
|
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||||
device=device,
|
dtype=dtype,
|
||||||
),
|
device=device,
|
||||||
torch.empty(
|
),
|
||||||
(num_blocks, num_heads, head_size, BLOCK_SIZE),
|
torch.empty(
|
||||||
dtype=dtype,
|
(num_blocks, num_heads, BLOCK_SIZE, head_size),
|
||||||
device=device,
|
dtype=dtype,
|
||||||
),
|
device=device,
|
||||||
)
|
),
|
||||||
for _ in range(num_layers)
|
)
|
||||||
]
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
self.kv_cache = [
|
||||||
|
(
|
||||||
|
torch.empty(
|
||||||
|
(num_blocks, num_heads, head_size // x, BLOCK_SIZE, x),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
torch.empty(
|
||||||
|
(num_blocks, num_heads, head_size, BLOCK_SIZE),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for _ in range(num_layers)
|
||||||
|
]
|
||||||
|
|
||||||
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
def cuda_graph_warmup(self, bs: int, max_s: int, max_bt: int):
|
||||||
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
input_ids = torch.zeros(bs, dtype=torch.int64, device=self.device)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from text_generation_server.utils import (
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
@ -37,6 +37,9 @@ class FlashGPT2(FlashCausalLM):
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IPEX_AVAIL:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashGPT2 is only available on GPU")
|
raise NotImplementedError("FlashGPT2 is only available on GPU")
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from text_generation_server.utils import (
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
|
|
||||||
class FlashLlama(FlashCausalLM):
|
class FlashLlama(FlashCausalLM):
|
||||||
|
@ -37,6 +37,9 @@ class FlashLlama(FlashCausalLM):
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IPEX_AVAIL:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashLlama is only available on GPU")
|
raise NotImplementedError("FlashLlama is only available on GPU")
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ from text_generation_server.utils import (
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
@ -41,6 +41,9 @@ class BaseFlashMistral(FlashCausalLM):
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IPEX_AVAIL:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashMistral is only available on GPU")
|
raise NotImplementedError("FlashMistral is only available on GPU")
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ from text_generation_server.utils import (
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
@ -36,6 +36,9 @@ class FlashNeoXSharded(FlashCausalLM):
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IPEX_AVAIL:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashNeoX is only available on GPU")
|
raise NotImplementedError("FlashNeoX is only available on GPU")
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ from text_generation_server.utils import (
|
||||||
weight_files,
|
weight_files,
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
@ -37,6 +37,9 @@ class FlashRWSharded(FlashCausalLM):
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IPEX_AVAIL:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashRW is only available on GPU")
|
raise NotImplementedError("FlashRW is only available on GPU")
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from text_generation_server.utils import (
|
||||||
Weights,
|
Weights,
|
||||||
)
|
)
|
||||||
|
|
||||||
from text_generation_server.utils.import_utils import SYSTEM
|
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
|
||||||
|
|
||||||
tracer = trace.get_tracer(__name__)
|
tracer = trace.get_tracer(__name__)
|
||||||
|
|
||||||
|
@ -40,6 +40,9 @@ class FlashSantacoderSharded(FlashCausalLM):
|
||||||
elif SYSTEM == "xpu":
|
elif SYSTEM == "xpu":
|
||||||
device = torch.device(f"xpu:{rank}")
|
device = torch.device(f"xpu:{rank}")
|
||||||
dtype = torch.float16 if dtype is None else dtype
|
dtype = torch.float16 if dtype is None else dtype
|
||||||
|
elif IPEX_AVAIL:
|
||||||
|
device = torch.device("cpu")
|
||||||
|
dtype = torch.bfloat16 if dtype is None else dtype
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import torch
|
||||||
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from text_generation_server.utils.import_utils import IPEX_AVAIL
|
||||||
|
|
||||||
# Tensor Parallelism settings
|
# Tensor Parallelism settings
|
||||||
RANK = int(os.getenv("RANK", "0"))
|
RANK = int(os.getenv("RANK", "0"))
|
||||||
|
@ -57,14 +58,7 @@ def initialize_torch_distributed():
|
||||||
options.is_high_priority_stream = True
|
options.is_high_priority_stream = True
|
||||||
options._timeout = timedelta(seconds=60)
|
options._timeout = timedelta(seconds=60)
|
||||||
else:
|
else:
|
||||||
try:
|
backend = "gloo"
|
||||||
import oneccl_bindings_for_pytorch
|
|
||||||
|
|
||||||
backend = "ccl"
|
|
||||||
if os.getenv("CCL_WORKER_COUNT", None) is None:
|
|
||||||
os.environ["CCL_WORKER_COUNT"] = str(1)
|
|
||||||
except ImportError:
|
|
||||||
backend = "gloo"
|
|
||||||
options = None
|
options = None
|
||||||
|
|
||||||
if WORLD_SIZE == 1:
|
if WORLD_SIZE == 1:
|
||||||
|
@ -75,13 +69,24 @@ def initialize_torch_distributed():
|
||||||
|
|
||||||
if not torch.distributed.is_initialized():
|
if not torch.distributed.is_initialized():
|
||||||
# Call the init process.
|
# Call the init process.
|
||||||
torch.distributed.init_process_group(
|
if IPEX_AVAIL:
|
||||||
backend=backend,
|
import intel_extension_for_pytorch as ipex
|
||||||
world_size=WORLD_SIZE,
|
|
||||||
rank=RANK,
|
ipex.distributed.init_process_group(
|
||||||
timeout=timedelta(seconds=60),
|
backend="ccl",
|
||||||
pg_options=options,
|
world_size=WORLD_SIZE,
|
||||||
)
|
rank=RANK,
|
||||||
|
timeout=timedelta(seconds=60),
|
||||||
|
pg_options=options,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
torch.distributed.init_process_group(
|
||||||
|
backend=backend,
|
||||||
|
world_size=WORLD_SIZE,
|
||||||
|
rank=RANK,
|
||||||
|
timeout=timedelta(seconds=60),
|
||||||
|
pg_options=options,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("torch.distributed is already initialized.")
|
logger.warning("torch.distributed is already initialized.")
|
||||||
|
|
||||||
|
|
|
@ -3,13 +3,12 @@ from loguru import logger
|
||||||
import subprocess
|
import subprocess
|
||||||
|
|
||||||
|
|
||||||
def is_xpu_available():
|
def is_ipex_available():
|
||||||
try:
|
try:
|
||||||
import intel_extension_for_pytorch
|
import intel_extension_for_pytorch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
return False
|
return False
|
||||||
|
return True
|
||||||
return hasattr(torch, "xpu") and torch.xpu.is_available()
|
|
||||||
|
|
||||||
|
|
||||||
def get_cuda_free_memory(device, memory_fraction):
|
def get_cuda_free_memory(device, memory_fraction):
|
||||||
|
@ -29,6 +28,16 @@ def get_xpu_free_memory(device, memory_fraction):
|
||||||
return free_memory
|
return free_memory
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_free_memory(device, memory_fraction):
|
||||||
|
import psutil
|
||||||
|
from text_generation_server.utils.dist import WORLD_SIZE
|
||||||
|
|
||||||
|
mem = psutil.virtual_memory()
|
||||||
|
free_memory = int(mem.available * 0.95 / WORLD_SIZE)
|
||||||
|
return free_memory
|
||||||
|
|
||||||
|
|
||||||
|
IPEX_AVAIL = is_ipex_available()
|
||||||
SYSTEM = None
|
SYSTEM = None
|
||||||
if torch.version.hip is not None:
|
if torch.version.hip is not None:
|
||||||
SYSTEM = "rocm"
|
SYSTEM = "rocm"
|
||||||
|
@ -40,7 +49,7 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
|
||||||
empty_cache = torch.cuda.empty_cache
|
empty_cache = torch.cuda.empty_cache
|
||||||
synchronize = torch.cuda.synchronize
|
synchronize = torch.cuda.synchronize
|
||||||
get_free_memory = get_cuda_free_memory
|
get_free_memory = get_cuda_free_memory
|
||||||
elif is_xpu_available():
|
elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available():
|
||||||
SYSTEM = "xpu"
|
SYSTEM = "xpu"
|
||||||
empty_cache = torch.xpu.empty_cache
|
empty_cache = torch.xpu.empty_cache
|
||||||
synchronize = torch.xpu.synchronize
|
synchronize = torch.xpu.synchronize
|
||||||
|
@ -53,5 +62,5 @@ else:
|
||||||
|
|
||||||
empty_cache = noop
|
empty_cache = noop
|
||||||
synchronize = noop
|
synchronize = noop
|
||||||
get_free_memory = noop
|
get_free_memory = get_cpu_free_memory
|
||||||
logger.info(f"Detected system {SYSTEM}")
|
logger.info(f"Detected system {SYSTEM}")
|
||||||
|
|
Loading…
Reference in New Issue