diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index 90fb9d45..0eb198f4 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -178,6 +178,6 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
           export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
-          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           echo $DOCKER_IMAGE
           pytest -s -vv integration-tests
diff --git a/.github/workflows/client-tests.yaml b/.github/workflows/client-tests.yaml
index ef7c217c..ff2928c4 100644
--- a/.github/workflows/client-tests.yaml
+++ b/.github/workflows/client-tests.yaml
@@ -22,5 +22,5 @@ jobs:
       - name: Run tests
         run: |
           pip install pytest pytest-asyncio
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           make python-client-tests
diff --git a/.github/workflows/integration_tests.yaml b/.github/workflows/integration_tests.yaml
index 4e111afe..59a8d304 100644
--- a/.github/workflows/integration_tests.yaml
+++ b/.github/workflows/integration_tests.yaml
@@ -37,5 +37,5 @@ jobs:
           export DOCKER_VOLUME=/mnt/cache
           export DOCKER_IMAGE=${{ inputs.docker_image }}
           export DOCKER_DEVICES=${{ inputs.docker_devices }}
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv integration-tests
diff --git a/.github/workflows/load_test.yaml b/.github/workflows/load_test.yaml
index a10c9428..637df472 100644
--- a/.github/workflows/load_test.yaml
+++ b/.github/workflows/load_test.yaml
@@ -28,7 +28,7 @@ jobs:
 
       - name: Start starcoder
        run: |
-          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
+          docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
           sleep 10
           wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health
 
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index e21344d1..f983b6ed 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -72,7 +72,7 @@ jobs:
       - name: Run server tests
         run: |
           pip install pytest
-          export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          export HF_TOKEN=${{ secrets.HF_TOKEN }}
           pytest -s -vv server/tests
       - name: Pre-commit checks
         run: |
diff --git a/clients/python/text_generation/types.py b/clients/python/text_generation/types.py
index 497468d9..a56edaca 100644
--- a/clients/python/text_generation/types.py
+++ b/clients/python/text_generation/types.py
@@ -455,6 +455,6 @@ class DeployedModel(BaseModel):
     # Disable warning for use of `model_` prefix in `model_id`. Be mindful about adding members
     # with model_ prefixes, since this disables guardrails for colliding fields:
     # https://github.com/pydantic/pydantic/issues/9177
-    model_config = ConfigDict(protected_namespaces=())
+    model_config = ConfigDict(protected_namespaces=())
     model_id: str
     sha: str
diff --git a/server/text_generation_server/layers/attention/__init__.py b/server/text_generation_server/layers/attention/__init__.py
index a53c8e3b..e74180e7 100644
--- a/server/text_generation_server/layers/attention/__init__.py
+++ b/server/text_generation_server/layers/attention/__init__.py
@@ -1,4 +1,4 @@
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 import os
 
 if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
@@ -7,7 +7,7 @@ if SYSTEM == "cuda":
     from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 elif SYSTEM == "rocm":
     from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
-elif IPEX_AVAIL:
-    from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
+elif SYSTEM == "ipex":
+    from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
 else:
     raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")
diff --git a/server/text_generation_server/layers/attention/xpu.py b/server/text_generation_server/layers/attention/ipex.py
similarity index 100%
rename from server/text_generation_server/layers/attention/xpu.py
rename to server/text_generation_server/layers/attention/ipex.py
diff --git a/server/text_generation_server/layers/layernorm.py b/server/text_generation_server/layers/layernorm.py
index 93e83dfa..ce5289f9 100644
--- a/server/text_generation_server/layers/layernorm.py
+++ b/server/text_generation_server/layers/layernorm.py
@@ -3,7 +3,6 @@ from torch import nn
 from accelerate import init_empty_weights
 from text_generation_server.utils.import_utils import (
     SYSTEM,
-    IPEX_AVAIL,
 )
 
 
@@ -83,7 +82,7 @@ elif SYSTEM == "rocm":
 
             return super().forward(hidden_states), residual
 
-elif IPEX_AVAIL:
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
     class FastLayerNorm(nn.LayerNorm):
@@ -112,7 +111,7 @@ class FastRMSNorm(nn.Module):
         return cls(weight, eps)
 
     def forward(self, hidden_states, residual=None):
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             out = ipex.llm.functional.add_rms_norm(
                 residual,
                 hidden_states,
diff --git a/server/text_generation_server/layers/rotary.py b/server/text_generation_server/layers/rotary.py
index 1892cf69..b14005e6 100644
--- a/server/text_generation_server/layers/rotary.py
+++ b/server/text_generation_server/layers/rotary.py
@@ -2,14 +2,14 @@ import os
 import torch
 from torch import nn
 
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 if SYSTEM == "cuda":
     from flash_attn.layers.rotary import RotaryEmbedding
     import rotary_emb
 elif SYSTEM == "rocm":
     from vllm._C import ops
-elif IPEX_AVAIL:
+elif SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
 
@@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
 
             # Inplace operation, updating query and key.
             ops.rotary_embedding(query, key, head_size, cos, sin, True)
-        elif IPEX_AVAIL:
+        elif SYSTEM == "ipex":
             ipex.llm.functional.rotary_embedding(
                 query, key, sin, cos, query.size(-1), True
             )
diff --git a/server/text_generation_server/layers/tensor_parallel.py b/server/text_generation_server/layers/tensor_parallel.py
index 510dc2c6..038de258 100644
--- a/server/text_generation_server/layers/tensor_parallel.py
+++ b/server/text_generation_server/layers/tensor_parallel.py
@@ -3,9 +3,9 @@ from torch.nn import functional as F
 from typing import Iterable, List
 from text_generation_server.layers.linear import get_linear, FastLinear
 from text_generation_server.layers.exl2 import Exl2Weight
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
-if IPEX_AVAIL:
+if SYSTEM == "ipex":
     import intel_extension_for_pytorch as ipex
 
 
@@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer):
                 local_out = gather_input.T
 
             torch.mm(input, self.linear.weight.T, out=local_out)
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_gather_into_tensor(
                     world_out, gather_input, group=self.process_group
                 )
@@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer):
         world_output = [
             torch.empty_like(output) for _ in range(self.process_group.size())
         ]
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             ipex.distributed.all_gather(world_output, output, group=self.process_group)
         else:
             torch.distributed.all_gather(world_output, output, group=self.process_group)
@@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer):
     def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
         out = super().forward(input)
         if self.process_group.size() > 1 and reduce:
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
@@ -257,7 +257,7 @@ class TensorParallelEmbedding(torch.nn.Module):
         )
         out = torch.nn.functional.embedding(input, self.weight)
         if self.reduce and self.process_group.size() > 1:
-            if IPEX_AVAIL:
+            if SYSTEM == "ipex":
                 ipex.distributed.all_reduce(out, group=self.process_group)
             else:
                 torch.distributed.all_reduce(out, group=self.process_group)
diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
index 56292250..f81bfa10 100644
--- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py
@@ -20,9 +20,9 @@ from torch import nn
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple, Any
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
-if not IPEX_AVAIL:
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 
 from text_generation_server.layers.attention import (
diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
index 6ea95411..2f7619af 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py
@@ -24,9 +24,9 @@ import torch.distributed
 import numpy as np
 
 from torch import nn
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
-if not IPEX_AVAIL:
+if SYSTEM != "ipex":
     from vllm.model_executor.layers.fused_moe import fused_moe
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 633b066b..aa43107f 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
 
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
 from text_generation_server.utils.chunks import concat_text_chunks
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 from text_generation_server.models import Model
 from text_generation_server.utils.tokens import batch_top_tokens
 from text_generation_server.utils.dist import RANK
@@ -768,12 +768,9 @@ class FlashCausalLM(Model):
         empty_cache()
 
         element_size = torch.tensor([], dtype=dtype).element_size()
-        if SYSTEM == "xpu":
-            x = 1
-        else:
-            x = BLOCK_SIZE // element_size
+        x = BLOCK_SIZE // element_size
 
-        if IPEX_AVAIL and SYSTEM == "cpu":
+        if SYSTEM == "ipex" and device == torch.device("cpu"):
             self.kv_cache = [
                 (
                     torch.empty(
diff --git a/server/text_generation_server/models/flash_gpt2.py b/server/text_generation_server/models/flash_gpt2.py
index 43f374e5..75c7203a 100644
--- a/server/text_generation_server/models/flash_gpt2.py
+++ b/server/text_generation_server/models/flash_gpt2.py
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -34,12 +34,12 @@ class FlashGPT2(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashGPT2 is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_llama.py b/server/text_generation_server/models/flash_llama.py
index e023c4e0..76c522e3 100644
--- a/server/text_generation_server/models/flash_llama.py
+++ b/server/text_generation_server/models/flash_llama.py
@@ -17,7 +17,7 @@ from text_generation_server.utils import (
 
 tracer = trace.get_tracer(__name__)
 
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 
 class FlashLlama(FlashCausalLM):
@@ -34,12 +34,12 @@ class FlashLlama(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 0c570487..78a09cf5 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -16,7 +16,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -38,12 +38,12 @@ class BaseFlashMistral(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashMistral is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_neox.py b/server/text_generation_server/models/flash_neox.py
index 69d47e57..9c82bf52 100644
--- a/server/text_generation_server/models/flash_neox.py
+++ b/server/text_generation_server/models/flash_neox.py
@@ -14,7 +14,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -33,12 +33,12 @@ class FlashNeoXSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_rw.py b/server/text_generation_server/models/flash_rw.py
index e087dcf1..e8087f23 100644
--- a/server/text_generation_server/models/flash_rw.py
+++ b/server/text_generation_server/models/flash_rw.py
@@ -15,7 +15,7 @@ from text_generation_server.utils import (
     weight_files,
     Weights,
 )
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -34,12 +34,12 @@ class FlashRWSharded(FlashCausalLM):
         if torch.cuda.is_available():
            device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashRW is only available on GPU")
 
diff --git a/server/text_generation_server/models/flash_santacoder.py b/server/text_generation_server/models/flash_santacoder.py
index 9626af60..83a6b92c 100644
--- a/server/text_generation_server/models/flash_santacoder.py
+++ b/server/text_generation_server/models/flash_santacoder.py
@@ -18,7 +18,7 @@ from text_generation_server.utils import (
     Weights,
 )
 
-from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 tracer = trace.get_tracer(__name__)
 
@@ -37,12 +37,12 @@ class FlashSantacoderSharded(FlashCausalLM):
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
             dtype = torch.float16 if dtype is None else dtype
-        elif SYSTEM == "xpu":
-            device = torch.device(f"xpu:{rank}")
+        elif SYSTEM == "ipex":
+            if hasattr(torch, "xpu") and torch.xpu.is_available():
+                device = torch.device(f"xpu:{rank}")
+            else:
+                device = torch.device("cpu")
             dtype = torch.float16 if dtype is None else dtype
-        elif IPEX_AVAIL:
-            device = torch.device("cpu")
-            dtype = torch.bfloat16 if dtype is None else dtype
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")
 
diff --git a/server/text_generation_server/utils/dist.py b/server/text_generation_server/utils/dist.py
index 7d387563..36d63e86 100644
--- a/server/text_generation_server/utils/dist.py
+++ b/server/text_generation_server/utils/dist.py
@@ -3,7 +3,7 @@ import torch
 
 from datetime import timedelta
 from loguru import logger
-from text_generation_server.utils.import_utils import IPEX_AVAIL
+from text_generation_server.utils.import_utils import SYSTEM
 
 # Tensor Parallelism settings
 RANK = int(os.getenv("RANK", "0"))
@@ -69,7 +69,7 @@ def initialize_torch_distributed():
 
     if not torch.distributed.is_initialized():
         # Call the init process.
-        if IPEX_AVAIL:
+        if SYSTEM == "ipex":
             import intel_extension_for_pytorch as ipex
 
             ipex.distributed.init_process_group(
diff --git a/server/text_generation_server/utils/import_utils.py b/server/text_generation_server/utils/import_utils.py
index a244417a..6d921721 100644
--- a/server/text_generation_server/utils/import_utils.py
+++ b/server/text_generation_server/utils/import_utils.py
@@ -37,7 +37,10 @@ def get_cpu_free_memory(device, memory_fraction):
     return free_memory
 
 
-IPEX_AVAIL = is_ipex_available()
+def noop(*args, **kwargs):
+    pass
+
+
 SYSTEM = None
 if torch.version.hip is not None:
     SYSTEM = "rocm"
@@ -49,17 +52,19 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
     empty_cache = torch.cuda.empty_cache
     synchronize = torch.cuda.synchronize
     get_free_memory = get_cuda_free_memory
-elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available():
-    SYSTEM = "xpu"
-    empty_cache = torch.xpu.empty_cache
-    synchronize = torch.xpu.synchronize
-    get_free_memory = get_xpu_free_memory
+elif is_ipex_available():
+    SYSTEM = "ipex"
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        empty_cache = torch.xpu.empty_cache
+        synchronize = torch.xpu.synchronize
+        get_free_memory = get_xpu_free_memory
+    else:
+        empty_cache = noop
+        synchronize = noop
+        get_free_memory = get_cpu_free_memory
 else:
     SYSTEM = "cpu"
 
-    def noop(*args, **kwargs):
-        pass
-
     empty_cache = noop
     synchronize = noop
     get_free_memory = get_cpu_free_memory
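
Note (not part of the patch): a minimal sketch of the device/dtype selection pattern this diff converges on in the Flash* model classes. Callers now branch on the single SYSTEM value exported by text_generation_server.utils.import_utils instead of a separate IPEX_AVAIL flag, and the "ipex" case covers both Intel XPU and CPU back ends. The pick_device_and_dtype helper below is hypothetical and only illustrates the pattern; it is not a function in the repository.

# Illustrative sketch only; mirrors the selection logic the diff applies to
# FlashGPT2, FlashLlama, BaseFlashMistral, and the other sharded models.
import torch

from text_generation_server.utils.import_utils import SYSTEM


def pick_device_and_dtype(rank, dtype=None):
    if torch.cuda.is_available():
        # CUDA (and ROCm, which also reports through torch.cuda) use one GPU per rank.
        device = torch.device(f"cuda:{rank}")
    elif SYSTEM == "ipex":
        # A single "ipex" system now covers both Intel XPU and CPU.
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            device = torch.device(f"xpu:{rank}")
        else:
            device = torch.device("cpu")
    else:
        raise NotImplementedError("Flash models require CUDA, ROCm, or IPEX")
    return device, torch.float16 if dtype is None else dtype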