Removing IPEX_AVAIL. (#2115)

* Removing IPEX_AVAIL.

Chose to unify CPU and XPU under `ipex`. Most code is exactly similar
except for a very few spots.

The biggest number of spots is the kv-cache layout and the flash_xxx.py
files.
Since those files should be removed soon and factored away, we should
not need them.

* Forgot a few places.

* Unrelated change.

* Fixing HF_TOKEN.

* HF_TOKEN
This commit is contained in:
Nicolas Patry 2024-06-25 13:20:57 +02:00 committed by GitHub
parent 3f3b7ffd67
commit 9e2fdf57c0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
22 changed files with 79 additions and 78 deletions

View File

@ -178,6 +178,6 @@ jobs:
export DOCKER_VOLUME=/mnt/cache export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }} export DOCKER_IMAGE=${{ needs.build-and-push.outputs.docker_image }}
export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }} export DOCKER_DEVICES=${{ needs.build-and-push.outputs.docker_devices }}
export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
echo $DOCKER_IMAGE echo $DOCKER_IMAGE
pytest -s -vv integration-tests pytest -s -vv integration-tests

View File

@ -22,5 +22,5 @@ jobs:
- name: Run tests - name: Run tests
run: | run: |
pip install pytest pytest-asyncio pip install pytest pytest-asyncio
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
make python-client-tests make python-client-tests

View File

@ -37,5 +37,5 @@ jobs:
export DOCKER_VOLUME=/mnt/cache export DOCKER_VOLUME=/mnt/cache
export DOCKER_IMAGE=${{ inputs.docker_image }} export DOCKER_IMAGE=${{ inputs.docker_image }}
export DOCKER_DEVICES=${{ inputs.docker_devices }} export DOCKER_DEVICES=${{ inputs.docker_devices }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv integration-tests pytest -s -vv integration-tests

View File

@ -28,7 +28,7 @@ jobs:
- name: Start starcoder - name: Start starcoder
run: | run: |
docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768 docker run --name tgi-starcoder --rm --gpus all -p 3000:80 -v /mnt/cache:/data -e HF_TOKEN=${{ secrets.HF_TOKEN }} --pull always -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder --num-shard 2 --max-batch-total-tokens 32768
sleep 10 sleep 10
wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health wget --timeout 10 --retry-on-http-error --waitretry=1 --tries=240 http://localhost:3000/health

View File

@ -72,7 +72,7 @@ jobs:
- name: Run server tests - name: Run server tests
run: | run: |
pip install pytest pip install pytest
export HF_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }} export HF_TOKEN=${{ secrets.HF_TOKEN }}
pytest -s -vv server/tests pytest -s -vv server/tests
- name: Pre-commit checks - name: Pre-commit checks
run: | run: |

View File

@ -1,4 +1,4 @@
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
import os import os
if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false": if os.getenv("USE_FLASH_ATTENTION", "").lower() == "false":
@ -7,7 +7,7 @@ if SYSTEM == "cuda":
from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .cuda import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .rocm import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
elif IPEX_AVAIL: elif SYSTEM == "ipex":
from .xpu import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING from .ipex import attention, paged_attention, reshape_and_cache, SUPPORTS_WINDOWING
else: else:
raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention") raise ImportError(f"System {SYSTEM} doesn't support flash/paged attention")

View File

@ -3,7 +3,6 @@ from torch import nn
from accelerate import init_empty_weights from accelerate import init_empty_weights
from text_generation_server.utils.import_utils import ( from text_generation_server.utils.import_utils import (
SYSTEM, SYSTEM,
IPEX_AVAIL,
) )
@ -83,7 +82,7 @@ elif SYSTEM == "rocm":
return super().forward(hidden_states), residual return super().forward(hidden_states), residual
elif IPEX_AVAIL: elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
class FastLayerNorm(nn.LayerNorm): class FastLayerNorm(nn.LayerNorm):
@ -112,7 +111,7 @@ class FastRMSNorm(nn.Module):
return cls(weight, eps) return cls(weight, eps)
def forward(self, hidden_states, residual=None): def forward(self, hidden_states, residual=None):
if IPEX_AVAIL: if SYSTEM == "ipex":
out = ipex.llm.functional.add_rms_norm( out = ipex.llm.functional.add_rms_norm(
residual, residual,
hidden_states, hidden_states,

View File

@ -2,14 +2,14 @@ import os
import torch import torch
from torch import nn from torch import nn
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
if SYSTEM == "cuda": if SYSTEM == "cuda":
from flash_attn.layers.rotary import RotaryEmbedding from flash_attn.layers.rotary import RotaryEmbedding
import rotary_emb import rotary_emb
elif SYSTEM == "rocm": elif SYSTEM == "rocm":
from vllm._C import ops from vllm._C import ops
elif IPEX_AVAIL: elif SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
@ -69,7 +69,7 @@ class PositionRotaryEmbedding(nn.Module):
# Inplace operation, updating query and key. # Inplace operation, updating query and key.
ops.rotary_embedding(query, key, head_size, cos, sin, True) ops.rotary_embedding(query, key, head_size, cos, sin, True)
elif IPEX_AVAIL: elif SYSTEM == "ipex":
ipex.llm.functional.rotary_embedding( ipex.llm.functional.rotary_embedding(
query, key, sin, cos, query.size(-1), True query, key, sin, cos, query.size(-1), True
) )

View File

@ -3,9 +3,9 @@ from torch.nn import functional as F
from typing import Iterable, List from typing import Iterable, List
from text_generation_server.layers.linear import get_linear, FastLinear from text_generation_server.layers.linear import get_linear, FastLinear
from text_generation_server.layers.exl2 import Exl2Weight from text_generation_server.layers.exl2 import Exl2Weight
from text_generation_server.utils.import_utils import IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
if IPEX_AVAIL: if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
@ -100,7 +100,7 @@ class TensorParallelHead(SuperLayer):
local_out = gather_input.T local_out = gather_input.T
torch.mm(input, self.linear.weight.T, out=local_out) torch.mm(input, self.linear.weight.T, out=local_out)
if IPEX_AVAIL: if SYSTEM == "ipex":
ipex.distributed.all_gather_into_tensor( ipex.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group world_out, gather_input, group=self.process_group
) )
@ -117,7 +117,7 @@ class TensorParallelHead(SuperLayer):
world_output = [ world_output = [
torch.empty_like(output) for _ in range(self.process_group.size()) torch.empty_like(output) for _ in range(self.process_group.size())
] ]
if IPEX_AVAIL: if SYSTEM == "ipex":
ipex.distributed.all_gather(world_output, output, group=self.process_group) ipex.distributed.all_gather(world_output, output, group=self.process_group)
else: else:
torch.distributed.all_gather(world_output, output, group=self.process_group) torch.distributed.all_gather(world_output, output, group=self.process_group)
@ -217,7 +217,7 @@ class TensorParallelRowLinear(SuperLayer):
def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor: def forward(self, input: torch.Tensor, reduce: bool = True) -> torch.Tensor:
out = super().forward(input) out = super().forward(input)
if self.process_group.size() > 1 and reduce: if self.process_group.size() > 1 and reduce:
if IPEX_AVAIL: if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group) ipex.distributed.all_reduce(out, group=self.process_group)
else: else:
torch.distributed.all_reduce(out, group=self.process_group) torch.distributed.all_reduce(out, group=self.process_group)
@ -257,7 +257,7 @@ class TensorParallelEmbedding(torch.nn.Module):
) )
out = torch.nn.functional.embedding(input, self.weight) out = torch.nn.functional.embedding(input, self.weight)
if self.reduce and self.process_group.size() > 1: if self.reduce and self.process_group.size() > 1:
if IPEX_AVAIL: if SYSTEM == "ipex":
ipex.distributed.all_reduce(out, group=self.process_group) ipex.distributed.all_reduce(out, group=self.process_group)
else: else:
torch.distributed.all_reduce(out, group=self.process_group) torch.distributed.all_reduce(out, group=self.process_group)

View File

@ -20,9 +20,9 @@ from torch import nn
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from typing import Optional, List, Tuple, Any from typing import Optional, List, Tuple, Any
from text_generation_server.utils.import_utils import IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
if not IPEX_AVAIL: if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from text_generation_server.layers.attention import ( from text_generation_server.layers.attention import (

View File

@ -24,9 +24,9 @@ import torch.distributed
import numpy as np import numpy as np
from torch import nn from torch import nn
from text_generation_server.utils.import_utils import IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
if not IPEX_AVAIL: if SYSTEM != "ipex":
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from transformers.activations import ACT2FN from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig

View File

@ -15,7 +15,7 @@ from typing import Iterable, Optional, Tuple, List, Type, Dict
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from text_generation_server.utils.chunks import concat_text_chunks from text_generation_server.utils.chunks import concat_text_chunks
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
from text_generation_server.models import Model from text_generation_server.models import Model
from text_generation_server.utils.tokens import batch_top_tokens from text_generation_server.utils.tokens import batch_top_tokens
from text_generation_server.utils.dist import RANK from text_generation_server.utils.dist import RANK
@ -768,12 +768,9 @@ class FlashCausalLM(Model):
empty_cache() empty_cache()
element_size = torch.tensor([], dtype=dtype).element_size() element_size = torch.tensor([], dtype=dtype).element_size()
if SYSTEM == "xpu":
x = 1
else:
x = BLOCK_SIZE // element_size x = BLOCK_SIZE // element_size
if IPEX_AVAIL and SYSTEM == "cpu": if SYSTEM == "ipex" and device == torch.device("cpu"):
self.kv_cache = [ self.kv_cache = [
( (
torch.empty( torch.empty(

View File

@ -15,7 +15,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -34,12 +34,12 @@ class FlashGPT2(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}") device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype else:
elif IPEX_AVAIL:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashGPT2 is only available on GPU") raise NotImplementedError("FlashGPT2 is only available on GPU")

View File

@ -17,7 +17,7 @@ from text_generation_server.utils import (
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
class FlashLlama(FlashCausalLM): class FlashLlama(FlashCausalLM):
@ -34,12 +34,12 @@ class FlashLlama(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}") device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype else:
elif IPEX_AVAIL:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashLlama is only available on GPU") raise NotImplementedError("FlashLlama is only available on GPU")

View File

@ -16,7 +16,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -38,12 +38,12 @@ class BaseFlashMistral(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}") device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype else:
elif IPEX_AVAIL:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashMistral is only available on GPU") raise NotImplementedError("FlashMistral is only available on GPU")

View File

@ -14,7 +14,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -33,12 +33,12 @@ class FlashNeoXSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}") device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype else:
elif IPEX_AVAIL:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashNeoX is only available on GPU") raise NotImplementedError("FlashNeoX is only available on GPU")

View File

@ -15,7 +15,7 @@ from text_generation_server.utils import (
weight_files, weight_files,
Weights, Weights,
) )
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -34,12 +34,12 @@ class FlashRWSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}") device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype else:
elif IPEX_AVAIL:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashRW is only available on GPU") raise NotImplementedError("FlashRW is only available on GPU")

View File

@ -18,7 +18,7 @@ from text_generation_server.utils import (
Weights, Weights,
) )
from text_generation_server.utils.import_utils import SYSTEM, IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
@ -37,12 +37,12 @@ class FlashSantacoderSharded(FlashCausalLM):
if torch.cuda.is_available(): if torch.cuda.is_available():
device = torch.device(f"cuda:{rank}") device = torch.device(f"cuda:{rank}")
dtype = torch.float16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
elif SYSTEM == "xpu": elif SYSTEM == "ipex":
if hasattr(torch, "xpu") and torch.xpu.is_available():
device = torch.device(f"xpu:{rank}") device = torch.device(f"xpu:{rank}")
dtype = torch.float16 if dtype is None else dtype else:
elif IPEX_AVAIL:
device = torch.device("cpu") device = torch.device("cpu")
dtype = torch.bfloat16 if dtype is None else dtype dtype = torch.float16 if dtype is None else dtype
else: else:
raise NotImplementedError("FlashSantacoderSharded is only available on GPU") raise NotImplementedError("FlashSantacoderSharded is only available on GPU")

View File

@ -3,7 +3,7 @@ import torch
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
from text_generation_server.utils.import_utils import IPEX_AVAIL from text_generation_server.utils.import_utils import SYSTEM
# Tensor Parallelism settings # Tensor Parallelism settings
RANK = int(os.getenv("RANK", "0")) RANK = int(os.getenv("RANK", "0"))
@ -69,7 +69,7 @@ def initialize_torch_distributed():
if not torch.distributed.is_initialized(): if not torch.distributed.is_initialized():
# Call the init process. # Call the init process.
if IPEX_AVAIL: if SYSTEM == "ipex":
import intel_extension_for_pytorch as ipex import intel_extension_for_pytorch as ipex
ipex.distributed.init_process_group( ipex.distributed.init_process_group(

View File

@ -37,7 +37,10 @@ def get_cpu_free_memory(device, memory_fraction):
return free_memory return free_memory
IPEX_AVAIL = is_ipex_available() def noop(*args, **kwargs):
pass
SYSTEM = None SYSTEM = None
if torch.version.hip is not None: if torch.version.hip is not None:
SYSTEM = "rocm" SYSTEM = "rocm"
@ -49,17 +52,19 @@ elif torch.version.cuda is not None and torch.cuda.is_available():
empty_cache = torch.cuda.empty_cache empty_cache = torch.cuda.empty_cache
synchronize = torch.cuda.synchronize synchronize = torch.cuda.synchronize
get_free_memory = get_cuda_free_memory get_free_memory = get_cuda_free_memory
elif IPEX_AVAIL and hasattr(torch, "xpu") and torch.xpu.is_available(): elif is_ipex_available():
SYSTEM = "xpu" SYSTEM = "ipex"
if hasattr(torch, "xpu") and torch.xpu.is_available():
empty_cache = torch.xpu.empty_cache empty_cache = torch.xpu.empty_cache
synchronize = torch.xpu.synchronize synchronize = torch.xpu.synchronize
get_free_memory = get_xpu_free_memory get_free_memory = get_xpu_free_memory
else:
empty_cache = noop
synchronize = noop
get_free_memory = get_cpu_free_memory
else: else:
SYSTEM = "cpu" SYSTEM = "cpu"
def noop(*args, **kwargs):
pass
empty_cache = noop empty_cache = noop
synchronize = noop synchronize = noop
get_free_memory = get_cpu_free_memory get_free_memory = get_cpu_free_memory